xref: /linux/net/netfilter/ipset/ip_set_core.c (revision d5a7fc58da039903b332041e8c67daae36f08b50)
1 // SPDX-License-Identifier: GPL-2.0-only
2 /* Copyright (C) 2000-2002 Joakim Axelsson <gozem@linux.nu>
3  *                         Patrick Schaaf <bof@bof.de>
4  * Copyright (C) 2003-2013 Jozsef Kadlecsik <kadlec@netfilter.org>
5  */
6 
7 /* Kernel module for IP set management */
8 
9 #include <linux/init.h>
10 #include <linux/module.h>
11 #include <linux/moduleparam.h>
12 #include <linux/ip.h>
13 #include <linux/skbuff.h>
14 #include <linux/spinlock.h>
15 #include <linux/rculist.h>
16 #include <net/netlink.h>
17 #include <net/net_namespace.h>
18 #include <net/netns/generic.h>
19 
20 #include <linux/netfilter.h>
21 #include <linux/netfilter/x_tables.h>
22 #include <linux/netfilter/nfnetlink.h>
23 #include <linux/netfilter/ipset/ip_set.h>
24 
/* Every registered set type; walked under RCU, modified under ip_set_type_mutex */
static LIST_HEAD(ip_set_type_list);
/* Serializes adding and removing entries of ip_set_type_list */
static DEFINE_MUTEX(ip_set_type_mutex);
/* Guards the per-set reference counters, and set names against renaming */
static DEFINE_RWLOCK(ip_set_ref_lock);
28 
29 struct ip_set_net {
30 	struct ip_set * __rcu *ip_set_list;	/* all individual sets */
31 	ip_set_id_t	ip_set_max;	/* max number of sets */
32 	bool		is_deleted;	/* deleted by ip_set_net_exit */
33 	bool		is_destroyed;	/* all sets are destroyed */
34 };
35 
36 static unsigned int ip_set_net_id __read_mostly;
37 
ip_set_pernet(struct net * net)38 static struct ip_set_net *ip_set_pernet(struct net *net)
39 {
40 	return net_generic(net, ip_set_net_id);
41 }
42 
43 #define IP_SET_INC	64
44 #define STRNCMP(a, b)	(strncmp(a, b, IPSET_MAXNAMELEN) == 0)
45 
46 static unsigned int max_sets;
47 
48 module_param(max_sets, int, 0600);
49 MODULE_PARM_DESC(max_sets, "maximal number of sets");
50 MODULE_LICENSE("GPL");
51 MODULE_AUTHOR("Jozsef Kadlecsik <kadlec@netfilter.org>");
52 MODULE_DESCRIPTION("core IP set support");
53 MODULE_ALIAS_NFNL_SUBSYS(NFNL_SUBSYS_IPSET);
54 
/* Access to the per-netns set array. ip_set_dereference() and ip_set() are
 * legal while the nfnl mutex or ip_set_ref_lock is held, or once
 * ip_set_net_exit() has flagged the namespace as deleted.
 */
#define ip_set_dereference(inst)					\
	rcu_dereference_protected((inst)->ip_set_list,			\
				  lockdep_nfnl_is_held(NFNL_SUBSYS_IPSET) || \
				  lockdep_is_held(&ip_set_ref_lock) ||	\
				  (inst)->is_deleted)
#define ip_set(inst, id)						\
	ip_set_dereference(inst)[id]
/* No lockdep check at all; presumably for holders of a netlink
 * reference (ref_netlink) - confirm at the call sites.
 */
#define ip_set_ref_netlink(inst, id)					\
	rcu_dereference_raw((inst)->ip_set_list)[id]
/* RCU read-side access that also accepts the nfnl mutex being held. */
#define ip_set_dereference_nfnl(p)					\
	rcu_dereference_check(p, lockdep_nfnl_is_held(NFNL_SUBSYS_IPSET))
67 
68 /* The set types are implemented in modules and registered set types
69  * can be found in ip_set_type_list. Adding/deleting types is
70  * serialized by ip_set_type_mutex.
71  */
72 
73 static void
ip_set_type_lock(void)74 ip_set_type_lock(void)
75 {
76 	mutex_lock(&ip_set_type_mutex);
77 }
78 
79 static void
ip_set_type_unlock(void)80 ip_set_type_unlock(void)
81 {
82 	mutex_unlock(&ip_set_type_mutex);
83 }
84 
85 /* Register and deregister settype */
86 
87 static struct ip_set_type *
find_set_type(const char * name,u8 family,u8 revision)88 find_set_type(const char *name, u8 family, u8 revision)
89 {
90 	struct ip_set_type *type;
91 
92 	list_for_each_entry_rcu(type, &ip_set_type_list, list,
93 				lockdep_is_held(&ip_set_type_mutex))
94 		if (STRNCMP(type->name, name) &&
95 		    (type->family == family ||
96 		     type->family == NFPROTO_UNSPEC) &&
97 		    revision >= type->revision_min &&
98 		    revision <= type->revision_max)
99 			return type;
100 	return NULL;
101 }
102 
103 /* Unlock, try to load a set type module and lock again */
104 static bool
load_settype(const char * name)105 load_settype(const char *name)
106 {
107 	nfnl_unlock(NFNL_SUBSYS_IPSET);
108 	pr_debug("try to load ip_set_%s\n", name);
109 	if (request_module("ip_set_%s", name) < 0) {
110 		pr_warn("Can't find ip_set type %s\n", name);
111 		nfnl_lock(NFNL_SUBSYS_IPSET);
112 		return false;
113 	}
114 	nfnl_lock(NFNL_SUBSYS_IPSET);
115 	return true;
116 }
117 
118 /* Find a set type and reference it */
119 #define find_set_type_get(name, family, revision, found)	\
120 	__find_set_type_get(name, family, revision, found, false)
121 
122 static int
__find_set_type_get(const char * name,u8 family,u8 revision,struct ip_set_type ** found,bool retry)123 __find_set_type_get(const char *name, u8 family, u8 revision,
124 		    struct ip_set_type **found, bool retry)
125 {
126 	struct ip_set_type *type;
127 	int err;
128 
129 	if (retry && !load_settype(name))
130 		return -IPSET_ERR_FIND_TYPE;
131 
132 	rcu_read_lock();
133 	*found = find_set_type(name, family, revision);
134 	if (*found) {
135 		err = !try_module_get((*found)->me) ? -EFAULT : 0;
136 		goto unlock;
137 	}
138 	/* Make sure the type is already loaded
139 	 * but we don't support the revision
140 	 */
141 	list_for_each_entry_rcu(type, &ip_set_type_list, list)
142 		if (STRNCMP(type->name, name)) {
143 			err = -IPSET_ERR_FIND_TYPE;
144 			goto unlock;
145 		}
146 	rcu_read_unlock();
147 
148 	return retry ? -IPSET_ERR_FIND_TYPE :
149 		__find_set_type_get(name, family, revision, found, true);
150 
151 unlock:
152 	rcu_read_unlock();
153 	return err;
154 }
155 
156 /* Find a given set type by name and family.
157  * If we succeeded, the supported minimal and maximum revisions are
158  * filled out.
159  */
160 #define find_set_type_minmax(name, family, min, max) \
161 	__find_set_type_minmax(name, family, min, max, false)
162 
163 static int
__find_set_type_minmax(const char * name,u8 family,u8 * min,u8 * max,bool retry)164 __find_set_type_minmax(const char *name, u8 family, u8 *min, u8 *max,
165 		       bool retry)
166 {
167 	struct ip_set_type *type;
168 	bool found = false;
169 
170 	if (retry && !load_settype(name))
171 		return -IPSET_ERR_FIND_TYPE;
172 
173 	*min = 255; *max = 0;
174 	rcu_read_lock();
175 	list_for_each_entry_rcu(type, &ip_set_type_list, list)
176 		if (STRNCMP(type->name, name) &&
177 		    (type->family == family ||
178 		     type->family == NFPROTO_UNSPEC)) {
179 			found = true;
180 			if (type->revision_min < *min)
181 				*min = type->revision_min;
182 			if (type->revision_max > *max)
183 				*max = type->revision_max;
184 		}
185 	rcu_read_unlock();
186 	if (found)
187 		return 0;
188 
189 	return retry ? -IPSET_ERR_FIND_TYPE :
190 		__find_set_type_minmax(name, family, min, max, true);
191 }
192 
/* Printable protocol family name used in log messages. */
#define family_name(f)						\
	((f) == NFPROTO_IPV4 ? "inet" :				\
	 ((f) == NFPROTO_IPV6 ? "inet6" : "any"))
195 
196 /* Register a set type structure. The type is identified by
197  * the unique triple of name, family and revision.
198  */
199 int
ip_set_type_register(struct ip_set_type * type)200 ip_set_type_register(struct ip_set_type *type)
201 {
202 	int ret = 0;
203 
204 	if (type->protocol != IPSET_PROTOCOL) {
205 		pr_warn("ip_set type %s, family %s, revision %u:%u uses wrong protocol version %u (want %u)\n",
206 			type->name, family_name(type->family),
207 			type->revision_min, type->revision_max,
208 			type->protocol, IPSET_PROTOCOL);
209 		return -EINVAL;
210 	}
211 
212 	ip_set_type_lock();
213 	if (find_set_type(type->name, type->family, type->revision_min)) {
214 		/* Duplicate! */
215 		pr_warn("ip_set type %s, family %s with revision min %u already registered!\n",
216 			type->name, family_name(type->family),
217 			type->revision_min);
218 		ip_set_type_unlock();
219 		return -EINVAL;
220 	}
221 	list_add_rcu(&type->list, &ip_set_type_list);
222 	pr_debug("type %s, family %s, revision %u:%u registered.\n",
223 		 type->name, family_name(type->family),
224 		 type->revision_min, type->revision_max);
225 	ip_set_type_unlock();
226 
227 	return ret;
228 }
229 EXPORT_SYMBOL_GPL(ip_set_type_register);
230 
231 /* Unregister a set type. There's a small race with ip_set_create */
232 void
ip_set_type_unregister(struct ip_set_type * type)233 ip_set_type_unregister(struct ip_set_type *type)
234 {
235 	ip_set_type_lock();
236 	if (!find_set_type(type->name, type->family, type->revision_min)) {
237 		pr_warn("ip_set type %s, family %s with revision min %u not registered\n",
238 			type->name, family_name(type->family),
239 			type->revision_min);
240 		ip_set_type_unlock();
241 		return;
242 	}
243 	list_del_rcu(&type->list);
244 	pr_debug("type %s, family %s with revision min %u unregistered.\n",
245 		 type->name, family_name(type->family), type->revision_min);
246 	ip_set_type_unlock();
247 
248 	synchronize_rcu();
249 }
250 EXPORT_SYMBOL_GPL(ip_set_type_unregister);
251 
252 /* Utility functions */
253 void *
ip_set_alloc(size_t size)254 ip_set_alloc(size_t size)
255 {
256 	return kvzalloc(size, GFP_KERNEL_ACCOUNT);
257 }
258 EXPORT_SYMBOL_GPL(ip_set_alloc);
259 
260 void
ip_set_free(void * members)261 ip_set_free(void *members)
262 {
263 	pr_debug("%p: free with %s\n", members,
264 		 is_vmalloc_addr(members) ? "vfree" : "kfree");
265 	kvfree(members);
266 }
267 EXPORT_SYMBOL_GPL(ip_set_free);
268 
269 static bool
flag_nested(const struct nlattr * nla)270 flag_nested(const struct nlattr *nla)
271 {
272 	return nla->nla_type & NLA_F_NESTED;
273 }
274 
275 static const struct nla_policy ipaddr_policy[IPSET_ATTR_IPADDR_MAX + 1] = {
276 	[IPSET_ATTR_IPADDR_IPV4]	= { .type = NLA_U32 },
277 	[IPSET_ATTR_IPADDR_IPV6]	= NLA_POLICY_EXACT_LEN(sizeof(struct in6_addr)),
278 };
279 
280 int
ip_set_get_ipaddr4(struct nlattr * nla,__be32 * ipaddr)281 ip_set_get_ipaddr4(struct nlattr *nla,  __be32 *ipaddr)
282 {
283 	struct nlattr *tb[IPSET_ATTR_IPADDR_MAX + 1];
284 
285 	if (unlikely(!flag_nested(nla)))
286 		return -IPSET_ERR_PROTOCOL;
287 	if (nla_parse_nested(tb, IPSET_ATTR_IPADDR_MAX, nla,
288 			     ipaddr_policy, NULL))
289 		return -IPSET_ERR_PROTOCOL;
290 	if (unlikely(!ip_set_attr_netorder(tb, IPSET_ATTR_IPADDR_IPV4)))
291 		return -IPSET_ERR_PROTOCOL;
292 
293 	*ipaddr = nla_get_be32(tb[IPSET_ATTR_IPADDR_IPV4]);
294 	return 0;
295 }
296 EXPORT_SYMBOL_GPL(ip_set_get_ipaddr4);
297 
298 int
ip_set_get_ipaddr6(struct nlattr * nla,union nf_inet_addr * ipaddr)299 ip_set_get_ipaddr6(struct nlattr *nla, union nf_inet_addr *ipaddr)
300 {
301 	struct nlattr *tb[IPSET_ATTR_IPADDR_MAX + 1];
302 
303 	if (unlikely(!flag_nested(nla)))
304 		return -IPSET_ERR_PROTOCOL;
305 
306 	if (nla_parse_nested(tb, IPSET_ATTR_IPADDR_MAX, nla,
307 			     ipaddr_policy, NULL))
308 		return -IPSET_ERR_PROTOCOL;
309 	if (unlikely(!ip_set_attr_netorder(tb, IPSET_ATTR_IPADDR_IPV6)))
310 		return -IPSET_ERR_PROTOCOL;
311 
312 	memcpy(ipaddr, nla_data(tb[IPSET_ATTR_IPADDR_IPV6]),
313 	       sizeof(struct in6_addr));
314 	return 0;
315 }
316 EXPORT_SYMBOL_GPL(ip_set_get_ipaddr6);
317 
318 static u32
ip_set_timeout_get(const unsigned long * timeout)319 ip_set_timeout_get(const unsigned long *timeout)
320 {
321 	u32 t;
322 
323 	if (*timeout == IPSET_ELEM_PERMANENT)
324 		return 0;
325 
326 	t = jiffies_to_msecs(*timeout - jiffies) / MSEC_PER_SEC;
327 	/* Zero value in userspace means no timeout */
328 	return t == 0 ? 1 : t;
329 }
330 
/* The comment text is the payload of the IPSET_ATTR_COMMENT attribute. */
static char *
ip_set_comment_uget(struct nlattr *tb)
{
	char *str = nla_data(tb);

	return str;
}
336 
337 /* Called from uadd only, protected by the set spinlock.
338  * The kadt functions don't use the comment extensions in any way.
339  */
340 void
ip_set_init_comment(struct ip_set * set,struct ip_set_comment * comment,const struct ip_set_ext * ext)341 ip_set_init_comment(struct ip_set *set, struct ip_set_comment *comment,
342 		    const struct ip_set_ext *ext)
343 {
344 	struct ip_set_comment_rcu *c = rcu_dereference_protected(comment->c, 1);
345 	size_t len = ext->comment ? strlen(ext->comment) : 0;
346 
347 	if (unlikely(c)) {
348 		set->ext_size -= sizeof(*c) + strlen(c->str) + 1;
349 		kfree_rcu(c, rcu);
350 		rcu_assign_pointer(comment->c, NULL);
351 	}
352 	if (!len)
353 		return;
354 	if (unlikely(len > IPSET_MAX_COMMENT_SIZE))
355 		len = IPSET_MAX_COMMENT_SIZE;
356 	c = kmalloc(sizeof(*c) + len + 1, GFP_ATOMIC);
357 	if (unlikely(!c))
358 		return;
359 	strscpy(c->str, ext->comment, len + 1);
360 	set->ext_size += sizeof(*c) + strlen(c->str) + 1;
361 	rcu_assign_pointer(comment->c, c);
362 }
363 EXPORT_SYMBOL_GPL(ip_set_init_comment);
364 
365 /* Used only when dumping a set, protected by rcu_read_lock() */
366 static int
ip_set_put_comment(struct sk_buff * skb,const struct ip_set_comment * comment)367 ip_set_put_comment(struct sk_buff *skb, const struct ip_set_comment *comment)
368 {
369 	struct ip_set_comment_rcu *c = rcu_dereference(comment->c);
370 
371 	if (!c)
372 		return 0;
373 	return nla_put_string(skb, IPSET_ATTR_COMMENT, c->str);
374 }
375 
376 /* Called from uadd/udel, flush or the garbage collectors protected
377  * by the set spinlock.
378  * Called when the set is destroyed and when there can't be any user
379  * of the set data anymore.
380  */
381 static void
ip_set_comment_free(struct ip_set * set,void * ptr)382 ip_set_comment_free(struct ip_set *set, void *ptr)
383 {
384 	struct ip_set_comment *comment = ptr;
385 	struct ip_set_comment_rcu *c;
386 
387 	c = rcu_dereference_protected(comment->c, 1);
388 	if (unlikely(!c))
389 		return;
390 	set->ext_size -= sizeof(*c) + strlen(c->str) + 1;
391 	kfree_rcu(c, rcu);
392 	rcu_assign_pointer(comment->c, NULL);
393 }
394 
395 typedef void (*destroyer)(struct ip_set *, void *);
396 /* ipset data extension types, in size order */
397 
398 const struct ip_set_ext_type ip_set_extensions[] = {
399 	[IPSET_EXT_ID_COUNTER] = {
400 		.type	= IPSET_EXT_COUNTER,
401 		.flag	= IPSET_FLAG_WITH_COUNTERS,
402 		.len	= sizeof(struct ip_set_counter),
403 		.align	= __alignof__(struct ip_set_counter),
404 	},
405 	[IPSET_EXT_ID_TIMEOUT] = {
406 		.type	= IPSET_EXT_TIMEOUT,
407 		.len	= sizeof(unsigned long),
408 		.align	= __alignof__(unsigned long),
409 	},
410 	[IPSET_EXT_ID_SKBINFO] = {
411 		.type	= IPSET_EXT_SKBINFO,
412 		.flag	= IPSET_FLAG_WITH_SKBINFO,
413 		.len	= sizeof(struct ip_set_skbinfo),
414 		.align	= __alignof__(struct ip_set_skbinfo),
415 	},
416 	[IPSET_EXT_ID_COMMENT] = {
417 		.type	 = IPSET_EXT_COMMENT | IPSET_EXT_DESTROY,
418 		.flag	 = IPSET_FLAG_WITH_COMMENT,
419 		.len	 = sizeof(struct ip_set_comment),
420 		.align	 = __alignof__(struct ip_set_comment),
421 		.destroy = ip_set_comment_free,
422 	},
423 };
424 EXPORT_SYMBOL_GPL(ip_set_extensions);
425 
426 static bool
add_extension(enum ip_set_ext_id id,u32 flags,struct nlattr * tb[])427 add_extension(enum ip_set_ext_id id, u32 flags, struct nlattr *tb[])
428 {
429 	return ip_set_extensions[id].flag ?
430 		(flags & ip_set_extensions[id].flag) :
431 		!!tb[IPSET_ATTR_TIMEOUT];
432 }
433 
434 size_t
ip_set_elem_len(struct ip_set * set,struct nlattr * tb[],size_t len,size_t align)435 ip_set_elem_len(struct ip_set *set, struct nlattr *tb[], size_t len,
436 		size_t align)
437 {
438 	enum ip_set_ext_id id;
439 	u32 cadt_flags = 0;
440 
441 	if (tb[IPSET_ATTR_CADT_FLAGS])
442 		cadt_flags = ip_set_get_h32(tb[IPSET_ATTR_CADT_FLAGS]);
443 	if (cadt_flags & IPSET_FLAG_WITH_FORCEADD)
444 		set->flags |= IPSET_CREATE_FLAG_FORCEADD;
445 	if (!align)
446 		align = 1;
447 	for (id = 0; id < IPSET_EXT_ID_MAX; id++) {
448 		if (!add_extension(id, cadt_flags, tb))
449 			continue;
450 		if (align < ip_set_extensions[id].align)
451 			align = ip_set_extensions[id].align;
452 		len = ALIGN(len, ip_set_extensions[id].align);
453 		set->offset[id] = len;
454 		set->extensions |= ip_set_extensions[id].type;
455 		len += ip_set_extensions[id].len;
456 	}
457 	return ALIGN(len, align);
458 }
459 EXPORT_SYMBOL_GPL(ip_set_elem_len);
460 
461 int
ip_set_get_extensions(struct ip_set * set,struct nlattr * tb[],struct ip_set_ext * ext)462 ip_set_get_extensions(struct ip_set *set, struct nlattr *tb[],
463 		      struct ip_set_ext *ext)
464 {
465 	u64 fullmark;
466 
467 	if (unlikely(!ip_set_optattr_netorder(tb, IPSET_ATTR_TIMEOUT) ||
468 		     !ip_set_optattr_netorder(tb, IPSET_ATTR_PACKETS) ||
469 		     !ip_set_optattr_netorder(tb, IPSET_ATTR_BYTES) ||
470 		     !ip_set_optattr_netorder(tb, IPSET_ATTR_SKBMARK) ||
471 		     !ip_set_optattr_netorder(tb, IPSET_ATTR_SKBPRIO) ||
472 		     !ip_set_optattr_netorder(tb, IPSET_ATTR_SKBQUEUE)))
473 		return -IPSET_ERR_PROTOCOL;
474 
475 	if (tb[IPSET_ATTR_TIMEOUT]) {
476 		if (!SET_WITH_TIMEOUT(set))
477 			return -IPSET_ERR_TIMEOUT;
478 		ext->timeout = ip_set_timeout_uget(tb[IPSET_ATTR_TIMEOUT]);
479 	}
480 	if (tb[IPSET_ATTR_BYTES] || tb[IPSET_ATTR_PACKETS]) {
481 		if (!SET_WITH_COUNTER(set))
482 			return -IPSET_ERR_COUNTER;
483 		if (tb[IPSET_ATTR_BYTES])
484 			ext->bytes = be64_to_cpu(nla_get_be64(
485 						 tb[IPSET_ATTR_BYTES]));
486 		if (tb[IPSET_ATTR_PACKETS])
487 			ext->packets = be64_to_cpu(nla_get_be64(
488 						   tb[IPSET_ATTR_PACKETS]));
489 	}
490 	if (tb[IPSET_ATTR_COMMENT]) {
491 		if (!SET_WITH_COMMENT(set))
492 			return -IPSET_ERR_COMMENT;
493 		ext->comment = ip_set_comment_uget(tb[IPSET_ATTR_COMMENT]);
494 	}
495 	if (tb[IPSET_ATTR_SKBMARK]) {
496 		if (!SET_WITH_SKBINFO(set))
497 			return -IPSET_ERR_SKBINFO;
498 		fullmark = be64_to_cpu(nla_get_be64(tb[IPSET_ATTR_SKBMARK]));
499 		ext->skbinfo.skbmark = fullmark >> 32;
500 		ext->skbinfo.skbmarkmask = fullmark & 0xffffffff;
501 	}
502 	if (tb[IPSET_ATTR_SKBPRIO]) {
503 		if (!SET_WITH_SKBINFO(set))
504 			return -IPSET_ERR_SKBINFO;
505 		ext->skbinfo.skbprio =
506 			be32_to_cpu(nla_get_be32(tb[IPSET_ATTR_SKBPRIO]));
507 	}
508 	if (tb[IPSET_ATTR_SKBQUEUE]) {
509 		if (!SET_WITH_SKBINFO(set))
510 			return -IPSET_ERR_SKBINFO;
511 		ext->skbinfo.skbqueue =
512 			be16_to_cpu(nla_get_be16(tb[IPSET_ATTR_SKBQUEUE]));
513 	}
514 	return 0;
515 }
516 EXPORT_SYMBOL_GPL(ip_set_get_extensions);
517 
518 static u64
ip_set_get_bytes(const struct ip_set_counter * counter)519 ip_set_get_bytes(const struct ip_set_counter *counter)
520 {
521 	return (u64)atomic64_read(&(counter)->bytes);
522 }
523 
524 static u64
ip_set_get_packets(const struct ip_set_counter * counter)525 ip_set_get_packets(const struct ip_set_counter *counter)
526 {
527 	return (u64)atomic64_read(&(counter)->packets);
528 }
529 
530 static bool
ip_set_put_counter(struct sk_buff * skb,const struct ip_set_counter * counter)531 ip_set_put_counter(struct sk_buff *skb, const struct ip_set_counter *counter)
532 {
533 	return nla_put_net64(skb, IPSET_ATTR_BYTES,
534 			     cpu_to_be64(ip_set_get_bytes(counter)),
535 			     IPSET_ATTR_PAD) ||
536 	       nla_put_net64(skb, IPSET_ATTR_PACKETS,
537 			     cpu_to_be64(ip_set_get_packets(counter)),
538 			     IPSET_ATTR_PAD);
539 }
540 
541 static bool
ip_set_put_skbinfo(struct sk_buff * skb,const struct ip_set_skbinfo * skbinfo)542 ip_set_put_skbinfo(struct sk_buff *skb, const struct ip_set_skbinfo *skbinfo)
543 {
544 	/* Send nonzero parameters only */
545 	return ((skbinfo->skbmark || skbinfo->skbmarkmask) &&
546 		nla_put_net64(skb, IPSET_ATTR_SKBMARK,
547 			      cpu_to_be64((u64)skbinfo->skbmark << 32 |
548 					  skbinfo->skbmarkmask),
549 			      IPSET_ATTR_PAD)) ||
550 	       (skbinfo->skbprio &&
551 		nla_put_net32(skb, IPSET_ATTR_SKBPRIO,
552 			      cpu_to_be32(skbinfo->skbprio))) ||
553 	       (skbinfo->skbqueue &&
554 		nla_put_net16(skb, IPSET_ATTR_SKBQUEUE,
555 			      cpu_to_be16(skbinfo->skbqueue)));
556 }
557 
558 int
ip_set_put_extensions(struct sk_buff * skb,const struct ip_set * set,const void * e,bool active)559 ip_set_put_extensions(struct sk_buff *skb, const struct ip_set *set,
560 		      const void *e, bool active)
561 {
562 	if (SET_WITH_TIMEOUT(set)) {
563 		unsigned long *timeout = ext_timeout(e, set);
564 
565 		if (nla_put_net32(skb, IPSET_ATTR_TIMEOUT,
566 			htonl(active ? ip_set_timeout_get(timeout)
567 				: *timeout)))
568 			return -EMSGSIZE;
569 	}
570 	if (SET_WITH_COUNTER(set) &&
571 	    ip_set_put_counter(skb, ext_counter(e, set)))
572 		return -EMSGSIZE;
573 	if (SET_WITH_COMMENT(set) &&
574 	    ip_set_put_comment(skb, ext_comment(e, set)))
575 		return -EMSGSIZE;
576 	if (SET_WITH_SKBINFO(set) &&
577 	    ip_set_put_skbinfo(skb, ext_skbinfo(e, set)))
578 		return -EMSGSIZE;
579 	return 0;
580 }
581 EXPORT_SYMBOL_GPL(ip_set_put_extensions);
582 
583 static bool
ip_set_match_counter(u64 counter,u64 match,u8 op)584 ip_set_match_counter(u64 counter, u64 match, u8 op)
585 {
586 	switch (op) {
587 	case IPSET_COUNTER_NONE:
588 		return true;
589 	case IPSET_COUNTER_EQ:
590 		return counter == match;
591 	case IPSET_COUNTER_NE:
592 		return counter != match;
593 	case IPSET_COUNTER_LT:
594 		return counter < match;
595 	case IPSET_COUNTER_GT:
596 		return counter > match;
597 	}
598 	return false;
599 }
600 
601 static void
ip_set_add_bytes(u64 bytes,struct ip_set_counter * counter)602 ip_set_add_bytes(u64 bytes, struct ip_set_counter *counter)
603 {
604 	atomic64_add((long long)bytes, &(counter)->bytes);
605 }
606 
607 static void
ip_set_add_packets(u64 packets,struct ip_set_counter * counter)608 ip_set_add_packets(u64 packets, struct ip_set_counter *counter)
609 {
610 	atomic64_add((long long)packets, &(counter)->packets);
611 }
612 
613 static void
ip_set_update_counter(struct ip_set_counter * counter,const struct ip_set_ext * ext,u32 flags)614 ip_set_update_counter(struct ip_set_counter *counter,
615 		      const struct ip_set_ext *ext, u32 flags)
616 {
617 	if (ext->packets != ULLONG_MAX &&
618 	    !(flags & IPSET_FLAG_SKIP_COUNTER_UPDATE)) {
619 		ip_set_add_bytes(ext->bytes, counter);
620 		ip_set_add_packets(ext->packets, counter);
621 	}
622 }
623 
624 static void
ip_set_get_skbinfo(struct ip_set_skbinfo * skbinfo,const struct ip_set_ext * ext,struct ip_set_ext * mext,u32 flags)625 ip_set_get_skbinfo(struct ip_set_skbinfo *skbinfo,
626 		   const struct ip_set_ext *ext,
627 		   struct ip_set_ext *mext, u32 flags)
628 {
629 	mext->skbinfo = *skbinfo;
630 }
631 
632 bool
ip_set_match_extensions(struct ip_set * set,const struct ip_set_ext * ext,struct ip_set_ext * mext,u32 flags,void * data)633 ip_set_match_extensions(struct ip_set *set, const struct ip_set_ext *ext,
634 			struct ip_set_ext *mext, u32 flags, void *data)
635 {
636 	if (SET_WITH_TIMEOUT(set) &&
637 	    ip_set_timeout_expired(ext_timeout(data, set)))
638 		return false;
639 	if (SET_WITH_COUNTER(set)) {
640 		struct ip_set_counter *counter = ext_counter(data, set);
641 
642 		ip_set_update_counter(counter, ext, flags);
643 
644 		if (flags & IPSET_FLAG_MATCH_COUNTERS &&
645 		    !(ip_set_match_counter(ip_set_get_packets(counter),
646 				mext->packets, mext->packets_op) &&
647 		      ip_set_match_counter(ip_set_get_bytes(counter),
648 				mext->bytes, mext->bytes_op)))
649 			return false;
650 	}
651 	if (SET_WITH_SKBINFO(set))
652 		ip_set_get_skbinfo(ext_skbinfo(data, set),
653 				   ext, mext, flags);
654 	return true;
655 }
656 EXPORT_SYMBOL_GPL(ip_set_match_extensions);
657 
658 /* Creating/destroying/renaming/swapping affect the existence and
659  * the properties of a set. All of these can be executed from userspace
660  * only and serialized by the nfnl mutex indirectly from nfnetlink.
661  *
662  * Sets are identified by their index in ip_set_list and the index
663  * is used by the external references (set/SET netfilter modules).
664  *
665  * The set behind an index may change by swapping only, from userspace.
666  */
667 
668 static void
__ip_set_get(struct ip_set * set)669 __ip_set_get(struct ip_set *set)
670 {
671 	write_lock_bh(&ip_set_ref_lock);
672 	set->ref++;
673 	write_unlock_bh(&ip_set_ref_lock);
674 }
675 
676 static void
__ip_set_put(struct ip_set * set)677 __ip_set_put(struct ip_set *set)
678 {
679 	write_lock_bh(&ip_set_ref_lock);
680 	BUG_ON(set->ref == 0);
681 	set->ref--;
682 	write_unlock_bh(&ip_set_ref_lock);
683 }
684 
685 /* set->ref can be swapped out by ip_set_swap, netlink events (like dump) need
686  * a separate reference counter
687  */
688 static void
__ip_set_get_netlink(struct ip_set * set)689 __ip_set_get_netlink(struct ip_set *set)
690 {
691 	write_lock_bh(&ip_set_ref_lock);
692 	set->ref_netlink++;
693 	write_unlock_bh(&ip_set_ref_lock);
694 }
695 
696 static void
__ip_set_put_netlink(struct ip_set * set)697 __ip_set_put_netlink(struct ip_set *set)
698 {
699 	write_lock_bh(&ip_set_ref_lock);
700 	BUG_ON(set->ref_netlink == 0);
701 	set->ref_netlink--;
702 	write_unlock_bh(&ip_set_ref_lock);
703 }
704 
705 /* Add, del and test set entries from kernel.
706  *
707  * The set behind the index must exist and must be referenced
708  * so it can't be destroyed (or changed) under our foot.
709  */
710 
711 static struct ip_set *
ip_set_rcu_get(struct net * net,ip_set_id_t index)712 ip_set_rcu_get(struct net *net, ip_set_id_t index)
713 {
714 	struct ip_set_net *inst = ip_set_pernet(net);
715 
716 	/* ip_set_list and the set pointer need to be protected */
717 	return ip_set_dereference_nfnl(inst->ip_set_list)[index];
718 }
719 
720 static inline void
ip_set_lock(struct ip_set * set)721 ip_set_lock(struct ip_set *set)
722 {
723 	if (!set->variant->region_lock)
724 		spin_lock_bh(&set->lock);
725 }
726 
727 static inline void
ip_set_unlock(struct ip_set * set)728 ip_set_unlock(struct ip_set *set)
729 {
730 	if (!set->variant->region_lock)
731 		spin_unlock_bh(&set->lock);
732 }
733 
734 int
ip_set_test(ip_set_id_t index,const struct sk_buff * skb,const struct xt_action_param * par,struct ip_set_adt_opt * opt)735 ip_set_test(ip_set_id_t index, const struct sk_buff *skb,
736 	    const struct xt_action_param *par, struct ip_set_adt_opt *opt)
737 {
738 	struct ip_set *set = ip_set_rcu_get(xt_net(par), index);
739 	int ret = 0;
740 
741 	BUG_ON(!set);
742 	pr_debug("set %s, index %u\n", set->name, index);
743 
744 	if (opt->dim < set->type->dimension ||
745 	    !(opt->family == set->family || set->family == NFPROTO_UNSPEC))
746 		return 0;
747 
748 	ret = set->variant->kadt(set, skb, par, IPSET_TEST, opt);
749 
750 	if (ret == -EAGAIN) {
751 		/* Type requests element to be completed */
752 		pr_debug("element must be completed, ADD is triggered\n");
753 		ip_set_lock(set);
754 		set->variant->kadt(set, skb, par, IPSET_ADD, opt);
755 		ip_set_unlock(set);
756 		ret = 1;
757 	} else {
758 		/* --return-nomatch: invert matched element */
759 		if ((opt->cmdflags & IPSET_FLAG_RETURN_NOMATCH) &&
760 		    (set->type->features & IPSET_TYPE_NOMATCH) &&
761 		    (ret > 0 || ret == -ENOTEMPTY))
762 			ret = -ret;
763 	}
764 
765 	/* Convert error codes to nomatch */
766 	return (ret < 0 ? 0 : ret);
767 }
768 EXPORT_SYMBOL_GPL(ip_set_test);
769 
770 int
ip_set_add(ip_set_id_t index,const struct sk_buff * skb,const struct xt_action_param * par,struct ip_set_adt_opt * opt)771 ip_set_add(ip_set_id_t index, const struct sk_buff *skb,
772 	   const struct xt_action_param *par, struct ip_set_adt_opt *opt)
773 {
774 	struct ip_set *set = ip_set_rcu_get(xt_net(par), index);
775 	int ret;
776 
777 	BUG_ON(!set);
778 	pr_debug("set %s, index %u\n", set->name, index);
779 
780 	if (opt->dim < set->type->dimension ||
781 	    !(opt->family == set->family || set->family == NFPROTO_UNSPEC))
782 		return -IPSET_ERR_TYPE_MISMATCH;
783 
784 	ip_set_lock(set);
785 	ret = set->variant->kadt(set, skb, par, IPSET_ADD, opt);
786 	ip_set_unlock(set);
787 
788 	return ret;
789 }
790 EXPORT_SYMBOL_GPL(ip_set_add);
791 
792 int
ip_set_del(ip_set_id_t index,const struct sk_buff * skb,const struct xt_action_param * par,struct ip_set_adt_opt * opt)793 ip_set_del(ip_set_id_t index, const struct sk_buff *skb,
794 	   const struct xt_action_param *par, struct ip_set_adt_opt *opt)
795 {
796 	struct ip_set *set = ip_set_rcu_get(xt_net(par), index);
797 	int ret = 0;
798 
799 	BUG_ON(!set);
800 	pr_debug("set %s, index %u\n", set->name, index);
801 
802 	if (opt->dim < set->type->dimension ||
803 	    !(opt->family == set->family || set->family == NFPROTO_UNSPEC))
804 		return -IPSET_ERR_TYPE_MISMATCH;
805 
806 	ip_set_lock(set);
807 	ret = set->variant->kadt(set, skb, par, IPSET_DEL, opt);
808 	ip_set_unlock(set);
809 
810 	return ret;
811 }
812 EXPORT_SYMBOL_GPL(ip_set_del);
813 
814 /* Find set by name, reference it once. The reference makes sure the
815  * thing pointed to, does not go away under our feet.
816  *
817  */
818 ip_set_id_t
ip_set_get_byname(struct net * net,const char * name,struct ip_set ** set)819 ip_set_get_byname(struct net *net, const char *name, struct ip_set **set)
820 {
821 	ip_set_id_t i, index = IPSET_INVALID_ID;
822 	struct ip_set *s;
823 	struct ip_set_net *inst = ip_set_pernet(net);
824 
825 	rcu_read_lock();
826 	for (i = 0; i < inst->ip_set_max; i++) {
827 		s = rcu_dereference(inst->ip_set_list)[i];
828 		if (s && STRNCMP(s->name, name)) {
829 			__ip_set_get(s);
830 			index = i;
831 			*set = s;
832 			break;
833 		}
834 	}
835 	rcu_read_unlock();
836 
837 	return index;
838 }
839 EXPORT_SYMBOL_GPL(ip_set_get_byname);
840 
841 /* If the given set pointer points to a valid set, decrement
842  * reference count by 1. The caller shall not assume the index
843  * to be valid, after calling this function.
844  *
845  */
846 
847 static void
__ip_set_put_byindex(struct ip_set_net * inst,ip_set_id_t index)848 __ip_set_put_byindex(struct ip_set_net *inst, ip_set_id_t index)
849 {
850 	struct ip_set *set;
851 
852 	rcu_read_lock();
853 	set = rcu_dereference(inst->ip_set_list)[index];
854 	if (set)
855 		__ip_set_put(set);
856 	rcu_read_unlock();
857 }
858 
859 void
ip_set_put_byindex(struct net * net,ip_set_id_t index)860 ip_set_put_byindex(struct net *net, ip_set_id_t index)
861 {
862 	struct ip_set_net *inst = ip_set_pernet(net);
863 
864 	__ip_set_put_byindex(inst, index);
865 }
866 EXPORT_SYMBOL_GPL(ip_set_put_byindex);
867 
868 /* Get the name of a set behind a set index.
869  * Set itself is protected by RCU, but its name isn't: to protect against
870  * renaming, grab ip_set_ref_lock as reader (see ip_set_rename()) and copy the
871  * name.
872  */
873 void
ip_set_name_byindex(struct net * net,ip_set_id_t index,char * name)874 ip_set_name_byindex(struct net *net, ip_set_id_t index, char *name)
875 {
876 	struct ip_set *set = ip_set_rcu_get(net, index);
877 
878 	BUG_ON(!set);
879 
880 	read_lock_bh(&ip_set_ref_lock);
881 	strscpy_pad(name, set->name, IPSET_MAXNAMELEN);
882 	read_unlock_bh(&ip_set_ref_lock);
883 }
884 EXPORT_SYMBOL_GPL(ip_set_name_byindex);
885 
886 /* Routines to call by external subsystems, which do not
887  * call nfnl_lock for us.
888  */
889 
890 /* Find set by index, reference it once. The reference makes sure the
891  * thing pointed to, does not go away under our feet.
892  *
893  * The nfnl mutex is used in the function.
894  */
895 ip_set_id_t
ip_set_nfnl_get_byindex(struct net * net,ip_set_id_t index)896 ip_set_nfnl_get_byindex(struct net *net, ip_set_id_t index)
897 {
898 	struct ip_set *set;
899 	struct ip_set_net *inst = ip_set_pernet(net);
900 
901 	if (index >= inst->ip_set_max)
902 		return IPSET_INVALID_ID;
903 
904 	nfnl_lock(NFNL_SUBSYS_IPSET);
905 	set = ip_set(inst, index);
906 	if (set)
907 		__ip_set_get(set);
908 	else
909 		index = IPSET_INVALID_ID;
910 	nfnl_unlock(NFNL_SUBSYS_IPSET);
911 
912 	return index;
913 }
914 EXPORT_SYMBOL_GPL(ip_set_nfnl_get_byindex);
915 
916 /* If the given set pointer points to a valid set, decrement
917  * reference count by 1. The caller shall not assume the index
918  * to be valid, after calling this function.
919  *
920  * The nfnl mutex is used in the function.
921  */
922 void
ip_set_nfnl_put(struct net * net,ip_set_id_t index)923 ip_set_nfnl_put(struct net *net, ip_set_id_t index)
924 {
925 	struct ip_set *set;
926 	struct ip_set_net *inst = ip_set_pernet(net);
927 
928 	nfnl_lock(NFNL_SUBSYS_IPSET);
929 	if (!inst->is_deleted) { /* already deleted from ip_set_net_exit() */
930 		set = ip_set(inst, index);
931 		if (set)
932 			__ip_set_put(set);
933 	}
934 	nfnl_unlock(NFNL_SUBSYS_IPSET);
935 }
936 EXPORT_SYMBOL_GPL(ip_set_nfnl_put);
937 
938 /* Communication protocol with userspace over netlink.
939  *
940  * The commands are serialized by the nfnl mutex.
941  */
942 
protocol(const struct nlattr * const tb[])943 static inline u8 protocol(const struct nlattr * const tb[])
944 {
945 	return nla_get_u8(tb[IPSET_ATTR_PROTOCOL]);
946 }
947 
948 static inline bool
protocol_failed(const struct nlattr * const tb[])949 protocol_failed(const struct nlattr * const tb[])
950 {
951 	return !tb[IPSET_ATTR_PROTOCOL] || protocol(tb) != IPSET_PROTOCOL;
952 }
953 
954 static inline bool
protocol_min_failed(const struct nlattr * const tb[])955 protocol_min_failed(const struct nlattr * const tb[])
956 {
957 	return !tb[IPSET_ATTR_PROTOCOL] || protocol(tb) < IPSET_PROTOCOL_MIN;
958 }
959 
960 static inline u32
flag_exist(const struct nlmsghdr * nlh)961 flag_exist(const struct nlmsghdr *nlh)
962 {
963 	return nlh->nlmsg_flags & NLM_F_EXCL ? 0 : IPSET_FLAG_EXIST;
964 }
965 
966 static struct nlmsghdr *
start_msg(struct sk_buff * skb,u32 portid,u32 seq,unsigned int flags,enum ipset_cmd cmd)967 start_msg(struct sk_buff *skb, u32 portid, u32 seq, unsigned int flags,
968 	  enum ipset_cmd cmd)
969 {
970 	return nfnl_msg_put(skb, portid, seq,
971 			    nfnl_msg_type(NFNL_SUBSYS_IPSET, cmd), flags,
972 			    NFPROTO_IPV4, NFNETLINK_V0, 0);
973 }
974 
975 /* Create a set */
976 
977 static const struct nla_policy ip_set_create_policy[IPSET_ATTR_CMD_MAX + 1] = {
978 	[IPSET_ATTR_PROTOCOL]	= { .type = NLA_U8 },
979 	[IPSET_ATTR_SETNAME]	= { .type = NLA_NUL_STRING,
980 				    .len = IPSET_MAXNAMELEN - 1 },
981 	[IPSET_ATTR_TYPENAME]	= { .type = NLA_NUL_STRING,
982 				    .len = IPSET_MAXNAMELEN - 1},
983 	[IPSET_ATTR_REVISION]	= { .type = NLA_U8 },
984 	[IPSET_ATTR_FAMILY]	= { .type = NLA_U8 },
985 	[IPSET_ATTR_DATA]	= { .type = NLA_NESTED },
986 };
987 
988 static struct ip_set *
find_set_and_id(struct ip_set_net * inst,const char * name,ip_set_id_t * id)989 find_set_and_id(struct ip_set_net *inst, const char *name, ip_set_id_t *id)
990 {
991 	struct ip_set *set = NULL;
992 	ip_set_id_t i;
993 
994 	*id = IPSET_INVALID_ID;
995 	for (i = 0; i < inst->ip_set_max; i++) {
996 		set = ip_set(inst, i);
997 		if (set && STRNCMP(set->name, name)) {
998 			*id = i;
999 			break;
1000 		}
1001 	}
1002 	return (*id == IPSET_INVALID_ID ? NULL : set);
1003 }
1004 
1005 static inline struct ip_set *
find_set(struct ip_set_net * inst,const char * name)1006 find_set(struct ip_set_net *inst, const char *name)
1007 {
1008 	ip_set_id_t id;
1009 
1010 	return find_set_and_id(inst, name, &id);
1011 }
1012 
1013 static int
find_free_id(struct ip_set_net * inst,const char * name,ip_set_id_t * index,struct ip_set ** set)1014 find_free_id(struct ip_set_net *inst, const char *name, ip_set_id_t *index,
1015 	     struct ip_set **set)
1016 {
1017 	struct ip_set *s;
1018 	ip_set_id_t i;
1019 
1020 	*index = IPSET_INVALID_ID;
1021 	for (i = 0;  i < inst->ip_set_max; i++) {
1022 		s = ip_set(inst, i);
1023 		if (!s) {
1024 			if (*index == IPSET_INVALID_ID)
1025 				*index = i;
1026 		} else if (STRNCMP(name, s->name)) {
1027 			/* Name clash */
1028 			*set = s;
1029 			return -EEXIST;
1030 		}
1031 	}
1032 	if (*index == IPSET_INVALID_ID)
1033 		/* No free slot remained */
1034 		return -IPSET_ERR_MAX_SETS;
1035 	return 0;
1036 }
1037 
ip_set_none(struct sk_buff * skb,const struct nfnl_info * info,const struct nlattr * const attr[])1038 static int ip_set_none(struct sk_buff *skb, const struct nfnl_info *info,
1039 		       const struct nlattr * const attr[])
1040 {
1041 	return -EOPNOTSUPP;
1042 }
1043 
/* Create a new set from the request attributes and install it in the
 * first free slot of the namespace's set table. When the table is full,
 * it is grown by IP_SET_INC entries.
 *
 * Required attributes: protocol, set name, type name, revision and
 * family. The optional IPSET_ATTR_DATA holds type-specific options,
 * which are parsed with the type's create_policy.
 *
 * If a set with the same name already exists, -EEXIST is returned. The
 * exception is when IPSET_FLAG_EXIST is in effect (no NLM_F_EXCL) and
 * the existing set has an identical type and parameters: then 0 is
 * returned and the newly built set is discarded.
 *
 * Returns 0 on success or a negative error code.
 */
static int ip_set_create(struct sk_buff *skb, const struct nfnl_info *info,
			 const struct nlattr * const attr[])
{
	struct ip_set_net *inst = ip_set_pernet(info->net);
	struct ip_set *set, *clash = NULL;
	ip_set_id_t index = IPSET_INVALID_ID;
	struct nlattr *tb[IPSET_ATTR_CREATE_MAX + 1] = {};
	const char *name, *typename;
	u8 family, revision;
	u32 flags = flag_exist(info->nlh);
	int ret = 0;

	if (unlikely(protocol_min_failed(attr) ||
		     !attr[IPSET_ATTR_SETNAME] ||
		     !attr[IPSET_ATTR_TYPENAME] ||
		     !attr[IPSET_ATTR_REVISION] ||
		     !attr[IPSET_ATTR_FAMILY] ||
		     (attr[IPSET_ATTR_DATA] &&
		      !flag_nested(attr[IPSET_ATTR_DATA]))))
		return -IPSET_ERR_PROTOCOL;

	name = nla_data(attr[IPSET_ATTR_SETNAME]);
	typename = nla_data(attr[IPSET_ATTR_TYPENAME]);
	family = nla_get_u8(attr[IPSET_ATTR_FAMILY]);
	revision = nla_get_u8(attr[IPSET_ATTR_REVISION]);
	pr_debug("setname: %s, typename: %s, family: %s, revision: %u\n",
		 name, typename, family_name(family), revision);

	/* First, and without any locks, allocate and initialize
	 * a normal base set structure.
	 * ("Without any locks" means no ipset-specific lock is taken; as
	 * noted below, the nfnl mutex protects this command.)
	 */
	set = kzalloc(sizeof(*set), GFP_KERNEL);
	if (!set)
		return -ENOMEM;
	spin_lock_init(&set->lock);
	strscpy(set->name, name, IPSET_MAXNAMELEN);
	set->family = family;
	set->revision = revision;

	/* Next, check that we know the type, and take
	 * a reference on the type, to make sure it stays available
	 * while constructing our new set.
	 *
	 * After referencing the type, we try to create the type
	 * specific part of the set without holding any locks.
	 */
	ret = find_set_type_get(typename, family, revision, &set->type);
	if (ret)
		goto out;

	/* Without holding any locks, create private part. */
	if (attr[IPSET_ATTR_DATA] &&
	    nla_parse_nested(tb, IPSET_ATTR_CREATE_MAX, attr[IPSET_ATTR_DATA],
			     set->type->create_policy, NULL)) {
		ret = -IPSET_ERR_PROTOCOL;
		goto put_out;
	}
	/* Set create flags depending on the type revision */
	set->flags |= set->type->create_flags[revision];

	ret = set->type->create(info->net, set, tb, flags);
	if (ret != 0)
		goto put_out;

	/* BTW, ret==0 here. */

	/* Here, we have a valid, constructed set and we are protected
	 * by the nfnl mutex. Find the first free index in ip_set_list
	 * and check clashing.
	 */
	ret = find_free_id(inst, set->name, &index, &clash);
	if (ret == -EEXIST) {
		/* If this is the same set and requested, ignore error */
		if ((flags & IPSET_FLAG_EXIST) &&
		    STRNCMP(set->type->name, clash->type->name) &&
		    set->type->family == clash->type->family &&
		    set->type->revision_min == clash->type->revision_min &&
		    set->type->revision_max == clash->type->revision_max &&
		    set->variant->same_set(set, clash))
			ret = 0;
		/* Either way the freshly built set is not installed */
		goto cleanup;
	} else if (ret == -IPSET_ERR_MAX_SETS) {
		/* Table full: grow it by IP_SET_INC slots. */
		struct ip_set **list, **tmp;
		ip_set_id_t i = inst->ip_set_max + IP_SET_INC;

		/* NOTE(review): on both failures below ret is still
		 * -IPSET_ERR_MAX_SETS, so an allocation failure is reported
		 * to userspace as "max sets reached" rather than -ENOMEM.
		 */
		if (i < inst->ip_set_max || i == IPSET_INVALID_ID)
			/* Wraparound */
			goto cleanup;

		list = kvcalloc(i, sizeof(struct ip_set *), GFP_KERNEL);
		if (!list)
			goto cleanup;
		/* nfnl mutex is held, both lists are valid */
		tmp = ip_set_dereference(inst);
		memcpy(list, tmp, sizeof(struct ip_set *) * inst->ip_set_max);
		rcu_assign_pointer(inst->ip_set_list, list);
		/* Make sure all current packets have passed through */
		synchronize_net();
		/* Use new list */
		/* The first slot of the new area is free for our set */
		index = inst->ip_set_max;
		inst->ip_set_max = i;
		kvfree(tmp);
		ret = 0;
	} else if (ret) {
		goto cleanup;
	}

	/* Finally! Add our shiny new set to the list, and be done. */
	pr_debug("create: '%s' created with index %u!\n", set->name, index);
	ip_set(inst, index) = set;

	return ret;

	/* Error unwinding, in reverse order of construction */
cleanup:
	set->variant->cancel_gc(set);
	set->variant->destroy(set);
put_out:
	module_put(set->type->me);
out:
	kfree(set);
	return ret;
}
1166 
1167 /* Destroy sets */
1168 
1169 static const struct nla_policy
1170 ip_set_setname_policy[IPSET_ATTR_CMD_MAX + 1] = {
1171 	[IPSET_ATTR_PROTOCOL]	= { .type = NLA_U8 },
1172 	[IPSET_ATTR_SETNAME]	= { .type = NLA_NUL_STRING,
1173 				    .len = IPSET_MAXNAMELEN - 1 },
1174 };
1175 
1176 /* In order to return quickly when destroying a single set, it is split
1177  * into two stages:
1178  * - Cancel garbage collector
1179  * - Destroy the set itself via call_rcu()
1180  */
1181 
1182 static void
ip_set_destroy_set_rcu(struct rcu_head * head)1183 ip_set_destroy_set_rcu(struct rcu_head *head)
1184 {
1185 	struct ip_set *set = container_of(head, struct ip_set, rcu);
1186 
1187 	set->variant->destroy(set);
1188 	module_put(set->type->me);
1189 	kfree(set);
1190 }
1191 
1192 static void
_destroy_all_sets(struct ip_set_net * inst)1193 _destroy_all_sets(struct ip_set_net *inst)
1194 {
1195 	struct ip_set *set;
1196 	ip_set_id_t i;
1197 	bool need_wait = false;
1198 
1199 	/* First cancel gc's: set:list sets are flushed as well */
1200 	for (i = 0; i < inst->ip_set_max; i++) {
1201 		set = ip_set(inst, i);
1202 		if (set) {
1203 			set->variant->cancel_gc(set);
1204 			if (set->type->features & IPSET_TYPE_NAME)
1205 				need_wait = true;
1206 		}
1207 	}
1208 	/* Must wait for flush to be really finished  */
1209 	if (need_wait)
1210 		rcu_barrier();
1211 	for (i = 0; i < inst->ip_set_max; i++) {
1212 		set = ip_set(inst, i);
1213 		if (set) {
1214 			ip_set(inst, i) = NULL;
1215 			set->variant->destroy(set);
1216 			module_put(set->type->me);
1217 			kfree(set);
1218 		}
1219 	}
1220 }
1221 
/* Destroy one set, or all sets of the namespace.
 *
 * Without IPSET_ATTR_SETNAME, every set is destroyed. The call fails with
 * -IPSET_ERR_BUSY if any set still holds a kernel reference (ref) or a
 * netlink dump reference (ref_netlink).
 *
 * With a set name, only that set is destroyed:
 *  - a missing set gives -ENOENT, unless IPSET_FLAG_EXIST is in effect
 *    (NLM_F_EXCL not requested), in which case 0 is returned;
 *  - a referenced set gives -IPSET_ERR_BUSY.
 * The final free of a single set is deferred via call_rcu().
 *
 * Returns 0 on success or a negative error code.
 */
static int ip_set_destroy(struct sk_buff *skb, const struct nfnl_info *info,
			  const struct nlattr * const attr[])
{
	struct ip_set_net *inst = ip_set_pernet(info->net);
	struct ip_set *s;
	ip_set_id_t i;
	int ret = 0;

	if (unlikely(protocol_min_failed(attr)))
		return -IPSET_ERR_PROTOCOL;

	/* Commands are serialized and references are
	 * protected by the ip_set_ref_lock.
	 * External systems (i.e. xt_set) must call
	 * ip_set_nfnl_get_* functions, that way we
	 * can safely check references here.
	 *
	 * list:set timer can only decrement the reference
	 * counter, so if it's already zero, we can proceed
	 * without holding the lock.
	 */
	if (!attr[IPSET_ATTR_SETNAME]) {
		read_lock_bh(&ip_set_ref_lock);
		for (i = 0; i < inst->ip_set_max; i++) {
			s = ip_set(inst, i);
			if (s && (s->ref || s->ref_netlink)) {
				ret = -IPSET_ERR_BUSY;
				goto out;
			}
		}
		/* Makes ip_set_dump_do() stop walking the table while the
		 * sets are torn down below.
		 */
		inst->is_destroyed = true;
		read_unlock_bh(&ip_set_ref_lock);
		_destroy_all_sets(inst);
		/* Modified by ip_set_destroy() only, which is serialized */
		inst->is_destroyed = false;
	} else {
		u32 flags = flag_exist(info->nlh);
		u16 features = 0;

		read_lock_bh(&ip_set_ref_lock);
		s = find_set_and_id(inst, nla_data(attr[IPSET_ATTR_SETNAME]),
				    &i);
		if (!s) {
			if (!(flags & IPSET_FLAG_EXIST))
				ret = -ENOENT;
			goto out;
		} else if (s->ref || s->ref_netlink) {
			ret = -IPSET_ERR_BUSY;
			goto out;
		}
		/* Unlink the set from the table while the ref lock is still
		 * held; the slot is immediately free for reuse.
		 */
		features = s->type->features;
		ip_set(inst, i) = NULL;
		read_unlock_bh(&ip_set_ref_lock);
		/* Must cancel garbage collectors */
		s->variant->cancel_gc(s);
		if (features & IPSET_TYPE_NAME) {
			/* Must wait for flush to be really finished  */
			rcu_barrier();
		}
		/* Second stage, see ip_set_destroy_set_rcu() */
		call_rcu(&s->rcu, ip_set_destroy_set_rcu);
	}
	return 0;
out:
	/* Only reached with ip_set_ref_lock read-held */
	read_unlock_bh(&ip_set_ref_lock);
	return ret;
}
1288 
1289 /* Flush sets */
1290 
1291 static void
ip_set_flush_set(struct ip_set * set)1292 ip_set_flush_set(struct ip_set *set)
1293 {
1294 	pr_debug("set: %s\n",  set->name);
1295 
1296 	ip_set_lock(set);
1297 	set->variant->flush(set);
1298 	ip_set_unlock(set);
1299 }
1300 
ip_set_flush(struct sk_buff * skb,const struct nfnl_info * info,const struct nlattr * const attr[])1301 static int ip_set_flush(struct sk_buff *skb, const struct nfnl_info *info,
1302 			const struct nlattr * const attr[])
1303 {
1304 	struct ip_set_net *inst = ip_set_pernet(info->net);
1305 	struct ip_set *s;
1306 	ip_set_id_t i;
1307 
1308 	if (unlikely(protocol_min_failed(attr)))
1309 		return -IPSET_ERR_PROTOCOL;
1310 
1311 	if (!attr[IPSET_ATTR_SETNAME]) {
1312 		for (i = 0; i < inst->ip_set_max; i++) {
1313 			s = ip_set(inst, i);
1314 			if (s)
1315 				ip_set_flush_set(s);
1316 		}
1317 	} else {
1318 		s = find_set(inst, nla_data(attr[IPSET_ATTR_SETNAME]));
1319 		if (!s)
1320 			return -ENOENT;
1321 
1322 		ip_set_flush_set(s);
1323 	}
1324 
1325 	return 0;
1326 }
1327 
1328 /* Rename a set */
1329 
1330 static const struct nla_policy
1331 ip_set_setname2_policy[IPSET_ATTR_CMD_MAX + 1] = {
1332 	[IPSET_ATTR_PROTOCOL]	= { .type = NLA_U8 },
1333 	[IPSET_ATTR_SETNAME]	= { .type = NLA_NUL_STRING,
1334 				    .len = IPSET_MAXNAMELEN - 1 },
1335 	[IPSET_ATTR_SETNAME2]	= { .type = NLA_NUL_STRING,
1336 				    .len = IPSET_MAXNAMELEN - 1 },
1337 };
1338 
ip_set_rename(struct sk_buff * skb,const struct nfnl_info * info,const struct nlattr * const attr[])1339 static int ip_set_rename(struct sk_buff *skb, const struct nfnl_info *info,
1340 			 const struct nlattr * const attr[])
1341 {
1342 	struct ip_set_net *inst = ip_set_pernet(info->net);
1343 	struct ip_set *set, *s;
1344 	const char *name2;
1345 	ip_set_id_t i;
1346 	int ret = 0;
1347 
1348 	if (unlikely(protocol_min_failed(attr) ||
1349 		     !attr[IPSET_ATTR_SETNAME] ||
1350 		     !attr[IPSET_ATTR_SETNAME2]))
1351 		return -IPSET_ERR_PROTOCOL;
1352 
1353 	set = find_set(inst, nla_data(attr[IPSET_ATTR_SETNAME]));
1354 	if (!set)
1355 		return -ENOENT;
1356 
1357 	write_lock_bh(&ip_set_ref_lock);
1358 	if (set->ref != 0 || set->ref_netlink != 0) {
1359 		ret = -IPSET_ERR_REFERENCED;
1360 		goto out;
1361 	}
1362 
1363 	name2 = nla_data(attr[IPSET_ATTR_SETNAME2]);
1364 	for (i = 0; i < inst->ip_set_max; i++) {
1365 		s = ip_set(inst, i);
1366 		if (s && STRNCMP(s->name, name2)) {
1367 			ret = -IPSET_ERR_EXIST_SETNAME2;
1368 			goto out;
1369 		}
1370 	}
1371 	strscpy_pad(set->name, name2, IPSET_MAXNAMELEN);
1372 
1373 out:
1374 	write_unlock_bh(&ip_set_ref_lock);
1375 	return ret;
1376 }
1377 
1378 /* Swap two sets so that name/index points to the other.
1379  * References and set names are also swapped.
1380  *
1381  * The commands are serialized by the nfnl mutex and references are
1382  * protected by the ip_set_ref_lock. The kernel interfaces
1383  * do not hold the mutex but the pointer settings are atomic
1384  * so the ip_set_list always contains valid pointers to the sets.
1385  */
1386 
ip_set_swap(struct sk_buff * skb,const struct nfnl_info * info,const struct nlattr * const attr[])1387 static int ip_set_swap(struct sk_buff *skb, const struct nfnl_info *info,
1388 		       const struct nlattr * const attr[])
1389 {
1390 	struct ip_set_net *inst = ip_set_pernet(info->net);
1391 	struct ip_set *from, *to;
1392 	ip_set_id_t from_id, to_id;
1393 	char from_name[IPSET_MAXNAMELEN];
1394 
1395 	if (unlikely(protocol_min_failed(attr) ||
1396 		     !attr[IPSET_ATTR_SETNAME] ||
1397 		     !attr[IPSET_ATTR_SETNAME2]))
1398 		return -IPSET_ERR_PROTOCOL;
1399 
1400 	from = find_set_and_id(inst, nla_data(attr[IPSET_ATTR_SETNAME]),
1401 			       &from_id);
1402 	if (!from)
1403 		return -ENOENT;
1404 
1405 	to = find_set_and_id(inst, nla_data(attr[IPSET_ATTR_SETNAME2]),
1406 			     &to_id);
1407 	if (!to)
1408 		return -IPSET_ERR_EXIST_SETNAME2;
1409 
1410 	/* Features must not change.
1411 	 * Not an artifical restriction anymore, as we must prevent
1412 	 * possible loops created by swapping in setlist type of sets.
1413 	 */
1414 	if (!(from->type->features == to->type->features &&
1415 	      from->family == to->family))
1416 		return -IPSET_ERR_TYPE_MISMATCH;
1417 
1418 	write_lock_bh(&ip_set_ref_lock);
1419 
1420 	if (from->ref_netlink || to->ref_netlink) {
1421 		write_unlock_bh(&ip_set_ref_lock);
1422 		return -EBUSY;
1423 	}
1424 
1425 	strscpy_pad(from_name, from->name, IPSET_MAXNAMELEN);
1426 	strscpy_pad(from->name, to->name, IPSET_MAXNAMELEN);
1427 	strscpy_pad(to->name, from_name, IPSET_MAXNAMELEN);
1428 
1429 	swap(from->ref, to->ref);
1430 	ip_set(inst, from_id) = to;
1431 	ip_set(inst, to_id) = from;
1432 	write_unlock_bh(&ip_set_ref_lock);
1433 
1434 	return 0;
1435 }
1436 
1437 /* List/save set data */
1438 
/* Dump states kept in the low 16 bits of cb->args[IPSET_CB_DUMP]; the
 * IPSET_FLAG_LIST_* flags from userspace are kept in the high 16 bits.
 */
#define DUMP_INIT	0	/* no dump requested yet */
#define DUMP_ALL	1	/* all sets, first pass */
#define DUMP_ONE	2	/* a single named set */
#define DUMP_LAST	3	/* all sets, pass for IPSET_DUMP_LAST types */

#define DUMP_TYPE(arg)		((u32)(arg) & 0xFFFFU)
#define DUMP_FLAGS(arg)		((u32)(arg) >> 16)
1446 
1447 int
ip_set_put_flags(struct sk_buff * skb,struct ip_set * set)1448 ip_set_put_flags(struct sk_buff *skb, struct ip_set *set)
1449 {
1450 	u32 cadt_flags = 0;
1451 
1452 	if (SET_WITH_TIMEOUT(set))
1453 		if (unlikely(nla_put_net32(skb, IPSET_ATTR_TIMEOUT,
1454 					   htonl(set->timeout))))
1455 			return -EMSGSIZE;
1456 	if (SET_WITH_COUNTER(set))
1457 		cadt_flags |= IPSET_FLAG_WITH_COUNTERS;
1458 	if (SET_WITH_COMMENT(set))
1459 		cadt_flags |= IPSET_FLAG_WITH_COMMENT;
1460 	if (SET_WITH_SKBINFO(set))
1461 		cadt_flags |= IPSET_FLAG_WITH_SKBINFO;
1462 	if (SET_WITH_FORCEADD(set))
1463 		cadt_flags |= IPSET_FLAG_WITH_FORCEADD;
1464 
1465 	if (!cadt_flags)
1466 		return 0;
1467 	return nla_put_net32(skb, IPSET_ATTR_CADT_FLAGS, htonl(cadt_flags));
1468 }
1469 EXPORT_SYMBOL_GPL(ip_set_put_flags);
1470 
1471 static int
ip_set_dump_done(struct netlink_callback * cb)1472 ip_set_dump_done(struct netlink_callback *cb)
1473 {
1474 	if (cb->args[IPSET_CB_ARG0]) {
1475 		struct ip_set_net *inst =
1476 			(struct ip_set_net *)cb->args[IPSET_CB_NET];
1477 		ip_set_id_t index = (ip_set_id_t)cb->args[IPSET_CB_INDEX];
1478 		struct ip_set *set = ip_set_ref_netlink(inst, index);
1479 
1480 		if (set->variant->uref)
1481 			set->variant->uref(set, cb, false);
1482 		pr_debug("release set %s\n", set->name);
1483 		__ip_set_put_netlink(set);
1484 	}
1485 	return 0;
1486 }
1487 
1488 static inline void
dump_attrs(struct nlmsghdr * nlh)1489 dump_attrs(struct nlmsghdr *nlh)
1490 {
1491 	const struct nlattr *attr;
1492 	int rem;
1493 
1494 	pr_debug("dump nlmsg\n");
1495 	nlmsg_for_each_attr(attr, nlh, sizeof(struct nfgenmsg), rem) {
1496 		pr_debug("type: %u, len %u\n", nla_type(attr), attr->nla_len);
1497 	}
1498 }
1499 
1500 static const struct nla_policy
1501 ip_set_dump_policy[IPSET_ATTR_CMD_MAX + 1] = {
1502 	[IPSET_ATTR_PROTOCOL]	= { .type = NLA_U8 },
1503 	[IPSET_ATTR_SETNAME]	= { .type = NLA_NUL_STRING,
1504 				    .len = IPSET_MAXNAMELEN - 1 },
1505 	[IPSET_ATTR_FLAGS]	= { .type = NLA_U32 },
1506 };
1507 
1508 static int
ip_set_dump_start(struct netlink_callback * cb)1509 ip_set_dump_start(struct netlink_callback *cb)
1510 {
1511 	struct nlmsghdr *nlh = nlmsg_hdr(cb->skb);
1512 	int min_len = nlmsg_total_size(sizeof(struct nfgenmsg));
1513 	struct nlattr *cda[IPSET_ATTR_CMD_MAX + 1];
1514 	struct nlattr *attr = (void *)nlh + min_len;
1515 	struct sk_buff *skb = cb->skb;
1516 	struct ip_set_net *inst = ip_set_pernet(sock_net(skb->sk));
1517 	u32 dump_type;
1518 	int ret;
1519 
1520 	ret = nla_parse(cda, IPSET_ATTR_CMD_MAX, attr,
1521 			nlh->nlmsg_len - min_len,
1522 			ip_set_dump_policy, NULL);
1523 	if (ret)
1524 		goto error;
1525 
1526 	cb->args[IPSET_CB_PROTO] = nla_get_u8(cda[IPSET_ATTR_PROTOCOL]);
1527 	if (cda[IPSET_ATTR_SETNAME]) {
1528 		ip_set_id_t index;
1529 		struct ip_set *set;
1530 
1531 		set = find_set_and_id(inst, nla_data(cda[IPSET_ATTR_SETNAME]),
1532 				      &index);
1533 		if (!set) {
1534 			ret = -ENOENT;
1535 			goto error;
1536 		}
1537 		dump_type = DUMP_ONE;
1538 		cb->args[IPSET_CB_INDEX] = index;
1539 	} else {
1540 		dump_type = DUMP_ALL;
1541 	}
1542 
1543 	if (cda[IPSET_ATTR_FLAGS]) {
1544 		u32 f = ip_set_get_h32(cda[IPSET_ATTR_FLAGS]);
1545 
1546 		dump_type |= (f << 16);
1547 	}
1548 	cb->args[IPSET_CB_NET] = (unsigned long)inst;
1549 	cb->args[IPSET_CB_DUMP] = dump_type;
1550 
1551 	return 0;
1552 
1553 error:
1554 	/* We have to create and send the error message manually :-( */
1555 	if (nlh->nlmsg_flags & NLM_F_ACK) {
1556 		netlink_ack(cb->skb, nlh, ret, NULL);
1557 	}
1558 	return ret;
1559 }
1560 
/* Netlink .dump callback for set listing (registered in ip_set_dump()).
 * Each set is emitted as one or more IPSET_CMD_LIST messages. The
 * function is resumable: a set whose members do not fit in @skb is
 * continued on the next call.
 *
 * Resumption state in cb->args[]:
 *   IPSET_CB_DUMP  - DUMP_ALL / DUMP_ONE / DUMP_LAST in the low 16 bits,
 *                    IPSET_FLAG_LIST_* flags in the high 16 bits
 *   IPSET_CB_INDEX - index of the next set to visit
 *   IPSET_CB_ARG0  - non-zero while the set at IPSET_CB_INDEX is only
 *                    partially listed; a ref_netlink reference is held
 *                    meanwhile. The value is presumably managed by the
 *                    type's ->list() callback.
 *   IPSET_CB_PROTO - protocol version requested by userspace
 *
 * When dumping all sets, the DUMP_ALL pass skips types flagged with
 * IPSET_DUMP_LAST, and a second DUMP_LAST pass emits only those.
 * This ensures lists (unions of sets) come after their members.
 *
 * Returns skb->len, or a negative error code.
 */
static int
ip_set_dump_do(struct sk_buff *skb, struct netlink_callback *cb)
{
	ip_set_id_t index = IPSET_INVALID_ID, max;
	struct ip_set *set = NULL;
	struct nlmsghdr *nlh = NULL;
	unsigned int flags = NETLINK_CB(cb->skb).portid ? NLM_F_MULTI : 0;
	struct ip_set_net *inst = ip_set_pernet(sock_net(skb->sk));
	u32 dump_type, dump_flags;
	bool is_destroyed;
	int ret = 0;

	/* ip_set_dump_start() always stores a non-zero dump type */
	if (!cb->args[IPSET_CB_DUMP])
		return -EINVAL;

	/* Nothing left; DUMP_ONE parks the index at IPSET_INVALID_ID */
	if (cb->args[IPSET_CB_INDEX] >= inst->ip_set_max)
		goto out;

	dump_type = DUMP_TYPE(cb->args[IPSET_CB_DUMP]);
	dump_flags = DUMP_FLAGS(cb->args[IPSET_CB_DUMP]);
	max = dump_type == DUMP_ONE ? cb->args[IPSET_CB_INDEX] + 1
				    : inst->ip_set_max;
dump_last:
	pr_debug("dump type, flag: %u %u index: %ld\n",
		 dump_type, dump_flags, cb->args[IPSET_CB_INDEX]);
	for (; cb->args[IPSET_CB_INDEX] < max; cb->args[IPSET_CB_INDEX]++) {
		index = (ip_set_id_t)cb->args[IPSET_CB_INDEX];
		/* Write lock: ref_netlink may be incremented below */
		write_lock_bh(&ip_set_ref_lock);
		set = ip_set(inst, index);
		is_destroyed = inst->is_destroyed;
		if (!set || is_destroyed) {
			write_unlock_bh(&ip_set_ref_lock);
			if (dump_type == DUMP_ONE) {
				ret = -ENOENT;
				goto out;
			}
			if (is_destroyed) {
				/* All sets are just being destroyed */
				ret = 0;
				goto out;
			}
			continue;
		}
		/* When dumping all sets, we must dump "sorted"
		 * so that lists (unions of sets) are dumped last.
		 */
		if (dump_type != DUMP_ONE &&
		    ((dump_type == DUMP_ALL) ==
		     !!(set->type->features & IPSET_DUMP_LAST))) {
			write_unlock_bh(&ip_set_ref_lock);
			continue;
		}
		pr_debug("List set: %s\n", set->name);
		if (!cb->args[IPSET_CB_ARG0]) {
			/* Start listing: make sure set won't be destroyed */
			pr_debug("reference set\n");
			set->ref_netlink++;
		}
		write_unlock_bh(&ip_set_ref_lock);
		nlh = start_msg(skb, NETLINK_CB(cb->skb).portid,
				cb->nlh->nlmsg_seq, flags,
				IPSET_CMD_LIST);
		if (!nlh) {
			ret = -EMSGSIZE;
			goto release_refcount;
		}
		if (nla_put_u8(skb, IPSET_ATTR_PROTOCOL,
			       cb->args[IPSET_CB_PROTO]) ||
		    nla_put_string(skb, IPSET_ATTR_SETNAME, set->name))
			goto nla_put_failure;
		if (dump_flags & IPSET_FLAG_LIST_SETNAME)
			goto next_set;
		switch (cb->args[IPSET_CB_ARG0]) {
		case 0:
			/* Core header data */
			if (nla_put_string(skb, IPSET_ATTR_TYPENAME,
					   set->type->name) ||
			    nla_put_u8(skb, IPSET_ATTR_FAMILY,
				       set->family) ||
			    nla_put_u8(skb, IPSET_ATTR_REVISION,
				       set->revision))
				goto nla_put_failure;
			/* Set index only for protocol versions above the minimum */
			if (cb->args[IPSET_CB_PROTO] > IPSET_PROTOCOL_MIN &&
			    nla_put_net16(skb, IPSET_ATTR_INDEX, htons(index)))
				goto nla_put_failure;
			ret = set->variant->head(set, skb);
			if (ret < 0)
				goto release_refcount;
			if (dump_flags & IPSET_FLAG_LIST_HEADER)
				goto next_set;
			if (set->variant->uref)
				set->variant->uref(set, cb, true);
			fallthrough;
		default:
			/* Members; may stop early with IPSET_CB_ARG0 still set */
			ret = set->variant->list(set, skb, cb);
			if (!cb->args[IPSET_CB_ARG0])
				/* Set is done, proceed with next one */
				goto next_set;
			goto release_refcount;
		}
	}
	/* If we dump all sets, continue with dumping last ones */
	if (dump_type == DUMP_ALL) {
		dump_type = DUMP_LAST;
		cb->args[IPSET_CB_DUMP] = dump_type | (dump_flags << 16);
		cb->args[IPSET_CB_INDEX] = 0;
		if (set && set->variant->uref)
			set->variant->uref(set, cb, false);
		goto dump_last;
	}
	goto out;

nla_put_failure:
	ret = -EFAULT;
next_set:
	/* Advance past the current set (DUMP_ONE: mark the dump finished) */
	if (dump_type == DUMP_ONE)
		cb->args[IPSET_CB_INDEX] = IPSET_INVALID_ID;
	else
		cb->args[IPSET_CB_INDEX]++;
release_refcount:
	/* If there was an error or set is done, release set */
	if (ret || !cb->args[IPSET_CB_ARG0]) {
		set = ip_set_ref_netlink(inst, index);
		if (set->variant->uref)
			set->variant->uref(set, cb, false);
		pr_debug("release set %s\n", set->name);
		__ip_set_put_netlink(set);
		cb->args[IPSET_CB_ARG0] = 0;
	}
out:
	if (nlh) {
		nlmsg_end(skb, nlh);
		pr_debug("nlmsg_len: %u\n", nlh->nlmsg_len);
		dump_attrs(nlh);
	}

	return ret < 0 ? ret : skb->len;
}
1699 
ip_set_dump(struct sk_buff * skb,const struct nfnl_info * info,const struct nlattr * const attr[])1700 static int ip_set_dump(struct sk_buff *skb, const struct nfnl_info *info,
1701 		       const struct nlattr * const attr[])
1702 {
1703 	if (unlikely(protocol_min_failed(attr)))
1704 		return -IPSET_ERR_PROTOCOL;
1705 
1706 	{
1707 		struct netlink_dump_control c = {
1708 			.start = ip_set_dump_start,
1709 			.dump = ip_set_dump_do,
1710 			.done = ip_set_dump_done,
1711 		};
1712 		return netlink_dump_start(info->sk, skb, info->nlh, &c);
1713 	}
1714 }
1715 
1716 /* Add, del and test */
1717 
1718 static const struct nla_policy ip_set_adt_policy[IPSET_ATTR_CMD_MAX + 1] = {
1719 	[IPSET_ATTR_PROTOCOL]	= { .type = NLA_U8 },
1720 	[IPSET_ATTR_SETNAME]	= { .type = NLA_NUL_STRING,
1721 				    .len = IPSET_MAXNAMELEN - 1 },
1722 	[IPSET_ATTR_LINENO]	= { .type = NLA_U32 },
1723 	[IPSET_ATTR_DATA]	= { .type = NLA_NESTED },
1724 	[IPSET_ATTR_ADT]	= { .type = NLA_NESTED },
1725 };
1726 
/* Apply one add/del element (attributes already parsed into @tb) via
 * the set type's ->uadt() callback, under the set lock.
 *
 * ->uadt() is called again (with retried == true) while it returns:
 *  - -ERANGE (presumably a partially processed range; confirm in the
 *    set types), or
 *  - -EAGAIN, provided the type has ->resize() and it succeeds.
 *    If ->resize() fails, its error becomes the result.
 * Before each retry the nfnl mutex is dropped so the task can
 * reschedule. A temporary ref_netlink reference keeps the set from
 * being destroyed in the meantime.
 *
 * Returns 0 on success; -IPSET_ERR_EXIST also yields 0 when
 * IPSET_FLAG_EXIST is set. In batch/restore mode (@use_lineno), when
 * ->uadt() reported a failing line number, a hand-built NLMSG_ERROR
 * carrying that line number is unicast to the requester and -EINTR is
 * returned, so netlink does not send its own ACK/error. Otherwise the
 * error code is returned.
 */
static int
call_ad(struct net *net, struct sock *ctnl, struct sk_buff *skb,
	struct ip_set *set, struct nlattr *tb[], enum ipset_adt adt,
	u32 flags, bool use_lineno)
{
	int ret;
	u32 lineno = 0;
	bool eexist = flags & IPSET_FLAG_EXIST, retried = false;

	do {
		if (retried) {
			/* Pin the set while the nfnl mutex is released */
			__ip_set_get_netlink(set);
			nfnl_unlock(NFNL_SUBSYS_IPSET);
			cond_resched();
			nfnl_lock(NFNL_SUBSYS_IPSET);
			__ip_set_put_netlink(set);
		}

		ip_set_lock(set);
		ret = set->variant->uadt(set, tb, adt, &lineno, flags, retried);
		ip_set_unlock(set);
		retried = true;
	} while (ret == -ERANGE ||
		 (ret == -EAGAIN &&
		  set->variant->resize &&
		  (ret = set->variant->resize(set, retried)) == 0));

	if (!ret || (ret == -IPSET_ERR_EXIST && eexist))
		return 0;
	if (lineno && use_lineno) {
		/* Error in restore/batch mode: send back lineno */
		struct nlmsghdr *rep, *nlh = nlmsg_hdr(skb);
		struct sk_buff *skb2;
		struct nlmsgerr *errmsg;
		size_t payload = min(SIZE_MAX,
				     sizeof(*errmsg) + nlmsg_len(nlh));
		int min_len = nlmsg_total_size(sizeof(struct nfgenmsg));
		struct nlattr *cda[IPSET_ATTR_CMD_MAX + 1];
		struct nlattr *cmdattr;
		u32 *errline;

		skb2 = nlmsg_new(payload, GFP_KERNEL);
		if (!skb2)
			return -ENOMEM;
		rep = nlmsg_put(skb2, NETLINK_CB(skb).portid,
				nlh->nlmsg_seq, NLMSG_ERROR, payload, 0);
		errmsg = nlmsg_data(rep);
		errmsg->error = ret;
		/* Echo the whole original request back inside the error */
		unsafe_memcpy(&errmsg->msg, nlh, nlh->nlmsg_len,
			      /* Bounds checked by the skb layer. */);

		/* Attributes of the echoed request follow its nlmsghdr and
		 * nfgenmsg headers.
		 */
		cmdattr = (void *)&errmsg->msg + min_len;

		ret = nla_parse(cda, IPSET_ATTR_CMD_MAX, cmdattr,
				nlh->nlmsg_len - min_len, ip_set_adt_policy,
				NULL);

		if (ret) {
			nlmsg_free(skb2);
			return ret;
		}
		/* Overwrite IPSET_ATTR_LINENO in the echoed copy with the
		 * line that failed.
		 */
		errline = nla_data(cda[IPSET_ATTR_LINENO]);

		*errline = lineno;

		nfnetlink_unicast(skb2, net, NETLINK_CB(skb).portid);
		/* Signal netlink not to send its ACK/errmsg.  */
		return -EINTR;
	}

	return ret;
}
1799 
ip_set_ad(struct net * net,struct sock * ctnl,struct sk_buff * skb,enum ipset_adt adt,const struct nlmsghdr * nlh,const struct nlattr * const attr[],struct netlink_ext_ack * extack)1800 static int ip_set_ad(struct net *net, struct sock *ctnl,
1801 		     struct sk_buff *skb,
1802 		     enum ipset_adt adt,
1803 		     const struct nlmsghdr *nlh,
1804 		     const struct nlattr * const attr[],
1805 		     struct netlink_ext_ack *extack)
1806 {
1807 	struct ip_set_net *inst = ip_set_pernet(net);
1808 	struct ip_set *set;
1809 	struct nlattr *tb[IPSET_ATTR_ADT_MAX + 1] = {};
1810 	const struct nlattr *nla;
1811 	u32 flags = flag_exist(nlh);
1812 	bool use_lineno;
1813 	int ret = 0;
1814 
1815 	if (unlikely(protocol_min_failed(attr) ||
1816 		     !attr[IPSET_ATTR_SETNAME] ||
1817 		     !((attr[IPSET_ATTR_DATA] != NULL) ^
1818 		       (attr[IPSET_ATTR_ADT] != NULL)) ||
1819 		     (attr[IPSET_ATTR_DATA] &&
1820 		      !flag_nested(attr[IPSET_ATTR_DATA])) ||
1821 		     (attr[IPSET_ATTR_ADT] &&
1822 		      (!flag_nested(attr[IPSET_ATTR_ADT]) ||
1823 		       !attr[IPSET_ATTR_LINENO]))))
1824 		return -IPSET_ERR_PROTOCOL;
1825 
1826 	set = find_set(inst, nla_data(attr[IPSET_ATTR_SETNAME]));
1827 	if (!set)
1828 		return -ENOENT;
1829 
1830 	use_lineno = !!attr[IPSET_ATTR_LINENO];
1831 	if (attr[IPSET_ATTR_DATA]) {
1832 		if (nla_parse_nested(tb, IPSET_ATTR_ADT_MAX,
1833 				     attr[IPSET_ATTR_DATA],
1834 				     set->type->adt_policy, NULL))
1835 			return -IPSET_ERR_PROTOCOL;
1836 		ret = call_ad(net, ctnl, skb, set, tb, adt, flags,
1837 			      use_lineno);
1838 	} else {
1839 		int nla_rem;
1840 
1841 		nla_for_each_nested(nla, attr[IPSET_ATTR_ADT], nla_rem) {
1842 			if (nla_type(nla) != IPSET_ATTR_DATA ||
1843 			    !flag_nested(nla) ||
1844 			    nla_parse_nested(tb, IPSET_ATTR_ADT_MAX, nla,
1845 					     set->type->adt_policy, NULL))
1846 				return -IPSET_ERR_PROTOCOL;
1847 			ret = call_ad(net, ctnl, skb, set, tb, adt,
1848 				      flags, use_lineno);
1849 			if (ret < 0)
1850 				return ret;
1851 		}
1852 	}
1853 	return ret;
1854 }
1855 
ip_set_uadd(struct sk_buff * skb,const struct nfnl_info * info,const struct nlattr * const attr[])1856 static int ip_set_uadd(struct sk_buff *skb, const struct nfnl_info *info,
1857 		       const struct nlattr * const attr[])
1858 {
1859 	return ip_set_ad(info->net, info->sk, skb,
1860 			 IPSET_ADD, info->nlh, attr, info->extack);
1861 }
1862 
ip_set_udel(struct sk_buff * skb,const struct nfnl_info * info,const struct nlattr * const attr[])1863 static int ip_set_udel(struct sk_buff *skb, const struct nfnl_info *info,
1864 		       const struct nlattr * const attr[])
1865 {
1866 	return ip_set_ad(info->net, info->sk, skb,
1867 			 IPSET_DEL, info->nlh, attr, info->extack);
1868 }
1869 
ip_set_utest(struct sk_buff * skb,const struct nfnl_info * info,const struct nlattr * const attr[])1870 static int ip_set_utest(struct sk_buff *skb, const struct nfnl_info *info,
1871 			const struct nlattr * const attr[])
1872 {
1873 	struct ip_set_net *inst = ip_set_pernet(info->net);
1874 	struct ip_set *set;
1875 	struct nlattr *tb[IPSET_ATTR_ADT_MAX + 1] = {};
1876 	int ret = 0;
1877 	u32 lineno;
1878 
1879 	if (unlikely(protocol_min_failed(attr) ||
1880 		     !attr[IPSET_ATTR_SETNAME] ||
1881 		     !attr[IPSET_ATTR_DATA] ||
1882 		     !flag_nested(attr[IPSET_ATTR_DATA])))
1883 		return -IPSET_ERR_PROTOCOL;
1884 
1885 	set = find_set(inst, nla_data(attr[IPSET_ATTR_SETNAME]));
1886 	if (!set)
1887 		return -ENOENT;
1888 
1889 	if (nla_parse_nested(tb, IPSET_ATTR_ADT_MAX, attr[IPSET_ATTR_DATA],
1890 			     set->type->adt_policy, NULL))
1891 		return -IPSET_ERR_PROTOCOL;
1892 
1893 	rcu_read_lock_bh();
1894 	ret = set->variant->uadt(set, tb, IPSET_TEST, &lineno, 0, 0);
1895 	rcu_read_unlock_bh();
1896 	/* Userspace can't trigger element to be re-added */
1897 	if (ret == -EAGAIN)
1898 		ret = 1;
1899 
1900 	return ret > 0 ? 0 : -IPSET_ERR_EXIST;
1901 }
1902 
1903 /* Get headed data of a set */
1904 
ip_set_header(struct sk_buff * skb,const struct nfnl_info * info,const struct nlattr * const attr[])1905 static int ip_set_header(struct sk_buff *skb, const struct nfnl_info *info,
1906 			 const struct nlattr * const attr[])
1907 {
1908 	struct ip_set_net *inst = ip_set_pernet(info->net);
1909 	const struct ip_set *set;
1910 	struct sk_buff *skb2;
1911 	struct nlmsghdr *nlh2;
1912 
1913 	if (unlikely(protocol_min_failed(attr) ||
1914 		     !attr[IPSET_ATTR_SETNAME]))
1915 		return -IPSET_ERR_PROTOCOL;
1916 
1917 	set = find_set(inst, nla_data(attr[IPSET_ATTR_SETNAME]));
1918 	if (!set)
1919 		return -ENOENT;
1920 
1921 	skb2 = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
1922 	if (!skb2)
1923 		return -ENOMEM;
1924 
1925 	nlh2 = start_msg(skb2, NETLINK_CB(skb).portid, info->nlh->nlmsg_seq, 0,
1926 			 IPSET_CMD_HEADER);
1927 	if (!nlh2)
1928 		goto nlmsg_failure;
1929 	if (nla_put_u8(skb2, IPSET_ATTR_PROTOCOL, protocol(attr)) ||
1930 	    nla_put_string(skb2, IPSET_ATTR_SETNAME, set->name) ||
1931 	    nla_put_string(skb2, IPSET_ATTR_TYPENAME, set->type->name) ||
1932 	    nla_put_u8(skb2, IPSET_ATTR_FAMILY, set->family) ||
1933 	    nla_put_u8(skb2, IPSET_ATTR_REVISION, set->revision))
1934 		goto nla_put_failure;
1935 	nlmsg_end(skb2, nlh2);
1936 
1937 	return nfnetlink_unicast(skb2, info->net, NETLINK_CB(skb).portid);
1938 
1939 nla_put_failure:
1940 	nlmsg_cancel(skb2, nlh2);
1941 nlmsg_failure:
1942 	kfree_skb(skb2);
1943 	return -EMSGSIZE;
1944 }
1945 
1946 /* Get type data */
1947 
1948 static const struct nla_policy ip_set_type_policy[IPSET_ATTR_CMD_MAX + 1] = {
1949 	[IPSET_ATTR_PROTOCOL]	= { .type = NLA_U8 },
1950 	[IPSET_ATTR_TYPENAME]	= { .type = NLA_NUL_STRING,
1951 				    .len = IPSET_MAXNAMELEN - 1 },
1952 	[IPSET_ATTR_FAMILY]	= { .type = NLA_U8 },
1953 };
1954 
ip_set_type(struct sk_buff * skb,const struct nfnl_info * info,const struct nlattr * const attr[])1955 static int ip_set_type(struct sk_buff *skb, const struct nfnl_info *info,
1956 		       const struct nlattr * const attr[])
1957 {
1958 	struct sk_buff *skb2;
1959 	struct nlmsghdr *nlh2;
1960 	u8 family, min, max;
1961 	const char *typename;
1962 	int ret = 0;
1963 
1964 	if (unlikely(protocol_min_failed(attr) ||
1965 		     !attr[IPSET_ATTR_TYPENAME] ||
1966 		     !attr[IPSET_ATTR_FAMILY]))
1967 		return -IPSET_ERR_PROTOCOL;
1968 
1969 	family = nla_get_u8(attr[IPSET_ATTR_FAMILY]);
1970 	typename = nla_data(attr[IPSET_ATTR_TYPENAME]);
1971 	ret = find_set_type_minmax(typename, family, &min, &max);
1972 	if (ret)
1973 		return ret;
1974 
1975 	skb2 = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
1976 	if (!skb2)
1977 		return -ENOMEM;
1978 
1979 	nlh2 = start_msg(skb2, NETLINK_CB(skb).portid, info->nlh->nlmsg_seq, 0,
1980 			 IPSET_CMD_TYPE);
1981 	if (!nlh2)
1982 		goto nlmsg_failure;
1983 	if (nla_put_u8(skb2, IPSET_ATTR_PROTOCOL, protocol(attr)) ||
1984 	    nla_put_string(skb2, IPSET_ATTR_TYPENAME, typename) ||
1985 	    nla_put_u8(skb2, IPSET_ATTR_FAMILY, family) ||
1986 	    nla_put_u8(skb2, IPSET_ATTR_REVISION, max) ||
1987 	    nla_put_u8(skb2, IPSET_ATTR_REVISION_MIN, min))
1988 		goto nla_put_failure;
1989 	nlmsg_end(skb2, nlh2);
1990 
1991 	pr_debug("Send TYPE, nlmsg_len: %u\n", nlh2->nlmsg_len);
1992 	return nfnetlink_unicast(skb2, info->net, NETLINK_CB(skb).portid);
1993 
1994 nla_put_failure:
1995 	nlmsg_cancel(skb2, nlh2);
1996 nlmsg_failure:
1997 	kfree_skb(skb2);
1998 	return -EMSGSIZE;
1999 }
2000 
2001 /* Get protocol version */
2002 
2003 static const struct nla_policy
2004 ip_set_protocol_policy[IPSET_ATTR_CMD_MAX + 1] = {
2005 	[IPSET_ATTR_PROTOCOL]	= { .type = NLA_U8 },
2006 };
2007 
ip_set_protocol(struct sk_buff * skb,const struct nfnl_info * info,const struct nlattr * const attr[])2008 static int ip_set_protocol(struct sk_buff *skb, const struct nfnl_info *info,
2009 			   const struct nlattr * const attr[])
2010 {
2011 	struct sk_buff *skb2;
2012 	struct nlmsghdr *nlh2;
2013 
2014 	if (unlikely(!attr[IPSET_ATTR_PROTOCOL]))
2015 		return -IPSET_ERR_PROTOCOL;
2016 
2017 	skb2 = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
2018 	if (!skb2)
2019 		return -ENOMEM;
2020 
2021 	nlh2 = start_msg(skb2, NETLINK_CB(skb).portid, info->nlh->nlmsg_seq, 0,
2022 			 IPSET_CMD_PROTOCOL);
2023 	if (!nlh2)
2024 		goto nlmsg_failure;
2025 	if (nla_put_u8(skb2, IPSET_ATTR_PROTOCOL, IPSET_PROTOCOL))
2026 		goto nla_put_failure;
2027 	if (nla_put_u8(skb2, IPSET_ATTR_PROTOCOL_MIN, IPSET_PROTOCOL_MIN))
2028 		goto nla_put_failure;
2029 	nlmsg_end(skb2, nlh2);
2030 
2031 	return nfnetlink_unicast(skb2, info->net, NETLINK_CB(skb).portid);
2032 
2033 nla_put_failure:
2034 	nlmsg_cancel(skb2, nlh2);
2035 nlmsg_failure:
2036 	kfree_skb(skb2);
2037 	return -EMSGSIZE;
2038 }
2039 
2040 /* Get set by name or index, from userspace */
2041 
ip_set_byname(struct sk_buff * skb,const struct nfnl_info * info,const struct nlattr * const attr[])2042 static int ip_set_byname(struct sk_buff *skb, const struct nfnl_info *info,
2043 			 const struct nlattr * const attr[])
2044 {
2045 	struct ip_set_net *inst = ip_set_pernet(info->net);
2046 	struct sk_buff *skb2;
2047 	struct nlmsghdr *nlh2;
2048 	ip_set_id_t id = IPSET_INVALID_ID;
2049 	const struct ip_set *set;
2050 
2051 	if (unlikely(protocol_failed(attr) ||
2052 		     !attr[IPSET_ATTR_SETNAME]))
2053 		return -IPSET_ERR_PROTOCOL;
2054 
2055 	set = find_set_and_id(inst, nla_data(attr[IPSET_ATTR_SETNAME]), &id);
2056 	if (id == IPSET_INVALID_ID)
2057 		return -ENOENT;
2058 
2059 	skb2 = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
2060 	if (!skb2)
2061 		return -ENOMEM;
2062 
2063 	nlh2 = start_msg(skb2, NETLINK_CB(skb).portid, info->nlh->nlmsg_seq, 0,
2064 			 IPSET_CMD_GET_BYNAME);
2065 	if (!nlh2)
2066 		goto nlmsg_failure;
2067 	if (nla_put_u8(skb2, IPSET_ATTR_PROTOCOL, protocol(attr)) ||
2068 	    nla_put_u8(skb2, IPSET_ATTR_FAMILY, set->family) ||
2069 	    nla_put_net16(skb2, IPSET_ATTR_INDEX, htons(id)))
2070 		goto nla_put_failure;
2071 	nlmsg_end(skb2, nlh2);
2072 
2073 	return nfnetlink_unicast(skb2, info->net, NETLINK_CB(skb).portid);
2074 
2075 nla_put_failure:
2076 	nlmsg_cancel(skb2, nlh2);
2077 nlmsg_failure:
2078 	kfree_skb(skb2);
2079 	return -EMSGSIZE;
2080 }
2081 
2082 static const struct nla_policy ip_set_index_policy[IPSET_ATTR_CMD_MAX + 1] = {
2083 	[IPSET_ATTR_PROTOCOL]	= { .type = NLA_U8 },
2084 	[IPSET_ATTR_INDEX]	= { .type = NLA_U16 },
2085 };
2086 
ip_set_byindex(struct sk_buff * skb,const struct nfnl_info * info,const struct nlattr * const attr[])2087 static int ip_set_byindex(struct sk_buff *skb, const struct nfnl_info *info,
2088 			  const struct nlattr * const attr[])
2089 {
2090 	struct ip_set_net *inst = ip_set_pernet(info->net);
2091 	struct sk_buff *skb2;
2092 	struct nlmsghdr *nlh2;
2093 	ip_set_id_t id = IPSET_INVALID_ID;
2094 	const struct ip_set *set;
2095 
2096 	if (unlikely(protocol_failed(attr) ||
2097 		     !attr[IPSET_ATTR_INDEX]))
2098 		return -IPSET_ERR_PROTOCOL;
2099 
2100 	id = ip_set_get_h16(attr[IPSET_ATTR_INDEX]);
2101 	if (id >= inst->ip_set_max)
2102 		return -ENOENT;
2103 	set = ip_set(inst, id);
2104 	if (set == NULL)
2105 		return -ENOENT;
2106 
2107 	skb2 = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
2108 	if (!skb2)
2109 		return -ENOMEM;
2110 
2111 	nlh2 = start_msg(skb2, NETLINK_CB(skb).portid, info->nlh->nlmsg_seq, 0,
2112 			 IPSET_CMD_GET_BYINDEX);
2113 	if (!nlh2)
2114 		goto nlmsg_failure;
2115 	if (nla_put_u8(skb2, IPSET_ATTR_PROTOCOL, protocol(attr)) ||
2116 	    nla_put_string(skb2, IPSET_ATTR_SETNAME, set->name))
2117 		goto nla_put_failure;
2118 	nlmsg_end(skb2, nlh2);
2119 
2120 	return nfnetlink_unicast(skb2, info->net, NETLINK_CB(skb).portid);
2121 
2122 nla_put_failure:
2123 	nlmsg_cancel(skb2, nlh2);
2124 nlmsg_failure:
2125 	kfree_skb(skb2);
2126 	return -EMSGSIZE;
2127 }
2128 
2129 static const struct nfnl_callback ip_set_netlink_subsys_cb[IPSET_MSG_MAX] = {
2130 	[IPSET_CMD_NONE]	= {
2131 		.call		= ip_set_none,
2132 		.type		= NFNL_CB_MUTEX,
2133 		.attr_count	= IPSET_ATTR_CMD_MAX,
2134 	},
2135 	[IPSET_CMD_CREATE]	= {
2136 		.call		= ip_set_create,
2137 		.type		= NFNL_CB_MUTEX,
2138 		.attr_count	= IPSET_ATTR_CMD_MAX,
2139 		.policy		= ip_set_create_policy,
2140 	},
2141 	[IPSET_CMD_DESTROY]	= {
2142 		.call		= ip_set_destroy,
2143 		.type		= NFNL_CB_MUTEX,
2144 		.attr_count	= IPSET_ATTR_CMD_MAX,
2145 		.policy		= ip_set_setname_policy,
2146 	},
2147 	[IPSET_CMD_FLUSH]	= {
2148 		.call		= ip_set_flush,
2149 		.type		= NFNL_CB_MUTEX,
2150 		.attr_count	= IPSET_ATTR_CMD_MAX,
2151 		.policy		= ip_set_setname_policy,
2152 	},
2153 	[IPSET_CMD_RENAME]	= {
2154 		.call		= ip_set_rename,
2155 		.type		= NFNL_CB_MUTEX,
2156 		.attr_count	= IPSET_ATTR_CMD_MAX,
2157 		.policy		= ip_set_setname2_policy,
2158 	},
2159 	[IPSET_CMD_SWAP]	= {
2160 		.call		= ip_set_swap,
2161 		.type		= NFNL_CB_MUTEX,
2162 		.attr_count	= IPSET_ATTR_CMD_MAX,
2163 		.policy		= ip_set_setname2_policy,
2164 	},
2165 	[IPSET_CMD_LIST]	= {
2166 		.call		= ip_set_dump,
2167 		.type		= NFNL_CB_MUTEX,
2168 		.attr_count	= IPSET_ATTR_CMD_MAX,
2169 		.policy		= ip_set_dump_policy,
2170 	},
2171 	[IPSET_CMD_SAVE]	= {
2172 		.call		= ip_set_dump,
2173 		.type		= NFNL_CB_MUTEX,
2174 		.attr_count	= IPSET_ATTR_CMD_MAX,
2175 		.policy		= ip_set_setname_policy,
2176 	},
2177 	[IPSET_CMD_ADD]	= {
2178 		.call		= ip_set_uadd,
2179 		.type		= NFNL_CB_MUTEX,
2180 		.attr_count	= IPSET_ATTR_CMD_MAX,
2181 		.policy		= ip_set_adt_policy,
2182 	},
2183 	[IPSET_CMD_DEL]	= {
2184 		.call		= ip_set_udel,
2185 		.type		= NFNL_CB_MUTEX,
2186 		.attr_count	= IPSET_ATTR_CMD_MAX,
2187 		.policy		= ip_set_adt_policy,
2188 	},
2189 	[IPSET_CMD_TEST]	= {
2190 		.call		= ip_set_utest,
2191 		.type		= NFNL_CB_MUTEX,
2192 		.attr_count	= IPSET_ATTR_CMD_MAX,
2193 		.policy		= ip_set_adt_policy,
2194 	},
2195 	[IPSET_CMD_HEADER]	= {
2196 		.call		= ip_set_header,
2197 		.type		= NFNL_CB_MUTEX,
2198 		.attr_count	= IPSET_ATTR_CMD_MAX,
2199 		.policy		= ip_set_setname_policy,
2200 	},
2201 	[IPSET_CMD_TYPE]	= {
2202 		.call		= ip_set_type,
2203 		.type		= NFNL_CB_MUTEX,
2204 		.attr_count	= IPSET_ATTR_CMD_MAX,
2205 		.policy		= ip_set_type_policy,
2206 	},
2207 	[IPSET_CMD_PROTOCOL]	= {
2208 		.call		= ip_set_protocol,
2209 		.type		= NFNL_CB_MUTEX,
2210 		.attr_count	= IPSET_ATTR_CMD_MAX,
2211 		.policy		= ip_set_protocol_policy,
2212 	},
2213 	[IPSET_CMD_GET_BYNAME]	= {
2214 		.call		= ip_set_byname,
2215 		.type		= NFNL_CB_MUTEX,
2216 		.attr_count	= IPSET_ATTR_CMD_MAX,
2217 		.policy		= ip_set_setname_policy,
2218 	},
2219 	[IPSET_CMD_GET_BYINDEX]	= {
2220 		.call		= ip_set_byindex,
2221 		.type		= NFNL_CB_MUTEX,
2222 		.attr_count	= IPSET_ATTR_CMD_MAX,
2223 		.policy		= ip_set_index_policy,
2224 	},
2225 };
2226 
2227 static struct nfnetlink_subsystem ip_set_netlink_subsys __read_mostly = {
2228 	.name		= "ip_set",
2229 	.subsys_id	= NFNL_SUBSYS_IPSET,
2230 	.cb_count	= IPSET_MSG_MAX,
2231 	.cb		= ip_set_netlink_subsys_cb,
2232 };
2233 
2234 /* Interface to iptables/ip6tables */
2235 
2236 static int
ip_set_sockfn_get(struct sock * sk,int optval,void __user * user,int * len)2237 ip_set_sockfn_get(struct sock *sk, int optval, void __user *user, int *len)
2238 {
2239 	unsigned int *op;
2240 	void *data;
2241 	int copylen = *len, ret = 0;
2242 	struct net *net = sock_net(sk);
2243 	struct ip_set_net *inst = ip_set_pernet(net);
2244 
2245 	if (!ns_capable(net->user_ns, CAP_NET_ADMIN))
2246 		return -EPERM;
2247 	if (optval != SO_IP_SET)
2248 		return -EBADF;
2249 	if (*len < sizeof(unsigned int))
2250 		return -EINVAL;
2251 
2252 	data = vmalloc(*len);
2253 	if (!data)
2254 		return -ENOMEM;
2255 	if (copy_from_user(data, user, *len) != 0) {
2256 		ret = -EFAULT;
2257 		goto done;
2258 	}
2259 	op = data;
2260 
2261 	if (*op < IP_SET_OP_VERSION) {
2262 		/* Check the version at the beginning of operations */
2263 		struct ip_set_req_version *req_version = data;
2264 
2265 		if (*len < sizeof(struct ip_set_req_version)) {
2266 			ret = -EINVAL;
2267 			goto done;
2268 		}
2269 
2270 		if (req_version->version < IPSET_PROTOCOL_MIN) {
2271 			ret = -EPROTO;
2272 			goto done;
2273 		}
2274 	}
2275 
2276 	switch (*op) {
2277 	case IP_SET_OP_VERSION: {
2278 		struct ip_set_req_version *req_version = data;
2279 
2280 		if (*len != sizeof(struct ip_set_req_version)) {
2281 			ret = -EINVAL;
2282 			goto done;
2283 		}
2284 
2285 		req_version->version = IPSET_PROTOCOL;
2286 		if (copy_to_user(user, req_version,
2287 				 sizeof(struct ip_set_req_version)))
2288 			ret = -EFAULT;
2289 		goto done;
2290 	}
2291 	case IP_SET_OP_GET_BYNAME: {
2292 		struct ip_set_req_get_set *req_get = data;
2293 		ip_set_id_t id;
2294 
2295 		if (*len != sizeof(struct ip_set_req_get_set)) {
2296 			ret = -EINVAL;
2297 			goto done;
2298 		}
2299 		req_get->set.name[IPSET_MAXNAMELEN - 1] = '\0';
2300 		nfnl_lock(NFNL_SUBSYS_IPSET);
2301 		find_set_and_id(inst, req_get->set.name, &id);
2302 		req_get->set.index = id;
2303 		nfnl_unlock(NFNL_SUBSYS_IPSET);
2304 		goto copy;
2305 	}
2306 	case IP_SET_OP_GET_FNAME: {
2307 		struct ip_set_req_get_set_family *req_get = data;
2308 		ip_set_id_t id;
2309 
2310 		if (*len != sizeof(struct ip_set_req_get_set_family)) {
2311 			ret = -EINVAL;
2312 			goto done;
2313 		}
2314 		req_get->set.name[IPSET_MAXNAMELEN - 1] = '\0';
2315 		nfnl_lock(NFNL_SUBSYS_IPSET);
2316 		find_set_and_id(inst, req_get->set.name, &id);
2317 		req_get->set.index = id;
2318 		if (id != IPSET_INVALID_ID)
2319 			req_get->family = ip_set(inst, id)->family;
2320 		nfnl_unlock(NFNL_SUBSYS_IPSET);
2321 		goto copy;
2322 	}
2323 	case IP_SET_OP_GET_BYINDEX: {
2324 		struct ip_set_req_get_set *req_get = data;
2325 		struct ip_set *set;
2326 
2327 		if (*len != sizeof(struct ip_set_req_get_set) ||
2328 		    req_get->set.index >= inst->ip_set_max) {
2329 			ret = -EINVAL;
2330 			goto done;
2331 		}
2332 		nfnl_lock(NFNL_SUBSYS_IPSET);
2333 		set = ip_set(inst, req_get->set.index);
2334 		ret = strscpy(req_get->set.name, set ? set->name : "",
2335 			      IPSET_MAXNAMELEN);
2336 		nfnl_unlock(NFNL_SUBSYS_IPSET);
2337 		if (ret < 0)
2338 			goto done;
2339 		goto copy;
2340 	}
2341 	default:
2342 		ret = -EBADMSG;
2343 		goto done;
2344 	}	/* end of switch(op) */
2345 
2346 copy:
2347 	if (copy_to_user(user, data, copylen))
2348 		ret = -EFAULT;
2349 
2350 done:
2351 	vfree(data);
2352 	if (ret > 0)
2353 		ret = 0;
2354 	return ret;
2355 }
2356 
2357 static struct nf_sockopt_ops so_set __read_mostly = {
2358 	.pf		= PF_INET,
2359 	.get_optmin	= SO_IP_SET,
2360 	.get_optmax	= SO_IP_SET + 1,
2361 	.get		= ip_set_sockfn_get,
2362 	.owner		= THIS_MODULE,
2363 };
2364 
2365 static int __net_init
ip_set_net_init(struct net * net)2366 ip_set_net_init(struct net *net)
2367 {
2368 	struct ip_set_net *inst = ip_set_pernet(net);
2369 	struct ip_set **list;
2370 
2371 	inst->ip_set_max = max_sets ? max_sets : CONFIG_IP_SET_MAX;
2372 	if (inst->ip_set_max >= IPSET_INVALID_ID)
2373 		inst->ip_set_max = IPSET_INVALID_ID - 1;
2374 
2375 	list = kvcalloc(inst->ip_set_max, sizeof(struct ip_set *), GFP_KERNEL);
2376 	if (!list)
2377 		return -ENOMEM;
2378 	inst->is_deleted = false;
2379 	inst->is_destroyed = false;
2380 	rcu_assign_pointer(inst->ip_set_list, list);
2381 	return 0;
2382 }
2383 
2384 static void __net_exit
ip_set_net_pre_exit(struct net * net)2385 ip_set_net_pre_exit(struct net *net)
2386 {
2387 	struct ip_set_net *inst = ip_set_pernet(net);
2388 
2389 	inst->is_deleted = true; /* flag for ip_set_nfnl_put */
2390 }
2391 
2392 static void __net_exit
ip_set_net_exit(struct net * net)2393 ip_set_net_exit(struct net *net)
2394 {
2395 	struct ip_set_net *inst = ip_set_pernet(net);
2396 
2397 	_destroy_all_sets(inst);
2398 	kvfree(rcu_dereference_protected(inst->ip_set_list, 1));
2399 }
2400 
2401 static struct pernet_operations ip_set_net_ops = {
2402 	.init	= ip_set_net_init,
2403 	.pre_exit = ip_set_net_pre_exit,
2404 	.exit   = ip_set_net_exit,
2405 	.id	= &ip_set_net_id,
2406 	.size	= sizeof(struct ip_set_net),
2407 };
2408 
2409 static int __init
ip_set_init(void)2410 ip_set_init(void)
2411 {
2412 	int ret = register_pernet_subsys(&ip_set_net_ops);
2413 
2414 	if (ret) {
2415 		pr_err("ip_set: cannot register pernet_subsys.\n");
2416 		return ret;
2417 	}
2418 
2419 	ret = nfnetlink_subsys_register(&ip_set_netlink_subsys);
2420 	if (ret != 0) {
2421 		pr_err("ip_set: cannot register with nfnetlink.\n");
2422 		unregister_pernet_subsys(&ip_set_net_ops);
2423 		return ret;
2424 	}
2425 
2426 	ret = nf_register_sockopt(&so_set);
2427 	if (ret != 0) {
2428 		pr_err("SO_SET registry failed: %d\n", ret);
2429 		nfnetlink_subsys_unregister(&ip_set_netlink_subsys);
2430 		unregister_pernet_subsys(&ip_set_net_ops);
2431 		return ret;
2432 	}
2433 
2434 	return 0;
2435 }
2436 
2437 static void __exit
ip_set_fini(void)2438 ip_set_fini(void)
2439 {
2440 	nf_unregister_sockopt(&so_set);
2441 	nfnetlink_subsys_unregister(&ip_set_netlink_subsys);
2442 	unregister_pernet_subsys(&ip_set_net_ops);
2443 
2444 	/* Wait for call_rcu() in destroy */
2445 	rcu_barrier();
2446 
2447 	pr_debug("these are the famous last words\n");
2448 }
2449 
/*
 * Module entry and exit points. The module description lives with the
 * other MODULE_* metadata near the top of the file; declaring a second
 * MODULE_DESCRIPTION here only added a duplicate "description=" entry to
 * .modinfo.
 */
module_init(ip_set_init);
module_exit(ip_set_fini);
2454