xref: /linux/tools/net/ynl/lib/ynl.c (revision 3d0fe49454652117522f60bfbefb978ba0e5300b)
1 // SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause
2 #include <errno.h>
3 #include <poll.h>
4 #include <string.h>
5 #include <stdlib.h>
6 #include <linux/types.h>
7 
8 #include <libmnl/libmnl.h>
9 #include <linux/genetlink.h>
10 
11 #include "ynl.h"
12 
/* Element count of a real array object (meaningless for a pointer). */
#define ARRAY_SIZE(arr)		(sizeof(arr) / sizeof((arr)[0]))
14 
/*
 * Error recording helpers.
 *
 * The "__" variants take a struct ynl_error pointer and silently do nothing
 * when it is NULL; the unprefixed variants take an object with an embedded
 * 'err' member (e.g. a socket) and forward its address.
 */

/* Format a printf-style message into ->msg; the result is always terminated. */
#define __yerr_msg(yse, _msg...)					\
	({								\
		struct ynl_error *_ye = (yse);				\
									\
		if (_ye) {						\
			snprintf(_ye->msg, sizeof(_ye->msg) - 1, _msg);	\
			_ye->msg[sizeof(_ye->msg) - 1] = 0;		\
		}							\
	})

/* Store an error code into ->code. */
#define __yerr_code(yse, _code...)		\
	({					\
		struct ynl_error *_ye = (yse);	\
						\
		if (_ye) {			\
			_ye->code = _code;	\
		}				\
	})

/* Record the message first, then the code. */
#define __yerr(yse, _code, _msg...)		\
	({					\
		__yerr_msg(yse, _msg);		\
		__yerr_code(yse, _code);	\
	})

/* Record a message together with the current errno value. */
#define __perr(yse, _msg)		__yerr(yse, errno, _msg)

#define yerr_msg(_ys, _msg...)		__yerr_msg(&(_ys)->err, _msg)
#define yerr(_ys, _code, _msg...)	__yerr(&(_ys)->err, _code, _msg)
#define perr(_ys, _msg)			__perr(&(_ys)->err, _msg)
45 
46 /* -- Netlink boiler plate */
47 static int
48 ynl_err_walk_report_one(struct ynl_policy_nest *policy, unsigned int type,
49 			char *str, int str_sz, int *n)
50 {
51 	if (!policy) {
52 		if (*n < str_sz)
53 			*n += snprintf(str, str_sz, "!policy");
54 		return 1;
55 	}
56 
57 	if (type > policy->max_attr) {
58 		if (*n < str_sz)
59 			*n += snprintf(str, str_sz, "!oob");
60 		return 1;
61 	}
62 
63 	if (!policy->table[type].name) {
64 		if (*n < str_sz)
65 			*n += snprintf(str, str_sz, "!name");
66 		return 1;
67 	}
68 
69 	if (*n < str_sz)
70 		*n += snprintf(str, str_sz - *n,
71 			       ".%s", policy->table[type].name);
72 	return 0;
73 }
74 
75 static int
76 ynl_err_walk(struct ynl_sock *ys, void *start, void *end, unsigned int off,
77 	     struct ynl_policy_nest *policy, char *str, int str_sz,
78 	     struct ynl_policy_nest **nest_pol)
79 {
80 	unsigned int astart_off, aend_off;
81 	const struct nlattr *attr;
82 	unsigned int data_len;
83 	unsigned int type;
84 	bool found = false;
85 	int n = 0;
86 
87 	if (!policy) {
88 		if (n < str_sz)
89 			n += snprintf(str, str_sz, "!policy");
90 		return n;
91 	}
92 
93 	data_len = end - start;
94 
95 	mnl_attr_for_each_payload(start, data_len) {
96 		astart_off = (char *)attr - (char *)start;
97 		aend_off = astart_off + mnl_attr_get_payload_len(attr);
98 		if (aend_off <= off)
99 			continue;
100 
101 		found = true;
102 		break;
103 	}
104 	if (!found)
105 		return 0;
106 
107 	off -= astart_off;
108 
109 	type = mnl_attr_get_type(attr);
110 
111 	if (ynl_err_walk_report_one(policy, type, str, str_sz, &n))
112 		return n;
113 
114 	if (!off) {
115 		if (nest_pol)
116 			*nest_pol = policy->table[type].nest;
117 		return n;
118 	}
119 
120 	if (!policy->table[type].nest) {
121 		if (n < str_sz)
122 			n += snprintf(str, str_sz, "!nest");
123 		return n;
124 	}
125 
126 	off -= sizeof(struct nlattr);
127 	start =  mnl_attr_get_payload(attr);
128 	end = start + mnl_attr_get_payload_len(attr);
129 
130 	return n + ynl_err_walk(ys, start, end, off, policy->table[type].nest,
131 				&str[n], str_sz - n, nest_pol);
132 }
133 
/*
 * Extended ACK attribute types numbered right after NLMSGERR_ATTR_POLICY.
 * NOTE(review): presumably defined locally for builds against uAPI headers
 * that lack them - confirm.
 *
 * NLMSGERR_ATTR_MAX deliberately names itself: a macro is not re-expanded
 * inside its own replacement, so the inner NLMSGERR_ATTR_MAX refers to
 * whatever non-macro definition the included headers provide.
 */
#define NLMSGERR_ATTR_MISS_TYPE		(NLMSGERR_ATTR_POLICY + 1)
#define NLMSGERR_ATTR_MISS_NEST		(NLMSGERR_ATTR_POLICY + 2)
#define NLMSGERR_ATTR_MAX		(NLMSGERR_ATTR_MAX + 2)
137 
138 static int
139 ynl_ext_ack_check(struct ynl_sock *ys, const struct nlmsghdr *nlh,
140 		  unsigned int hlen)
141 {
142 	const struct nlattr *tb[NLMSGERR_ATTR_MAX + 1] = {};
143 	char miss_attr[sizeof(ys->err.msg)];
144 	char bad_attr[sizeof(ys->err.msg)];
145 	const struct nlattr *attr;
146 	const char *str = NULL;
147 
148 	if (!(nlh->nlmsg_flags & NLM_F_ACK_TLVS))
149 		return MNL_CB_OK;
150 
151 	mnl_attr_for_each(attr, nlh, hlen) {
152 		unsigned int len, type;
153 
154 		len = mnl_attr_get_payload_len(attr);
155 		type = mnl_attr_get_type(attr);
156 
157 		if (type > NLMSGERR_ATTR_MAX)
158 			continue;
159 
160 		tb[type] = attr;
161 
162 		switch (type) {
163 		case NLMSGERR_ATTR_OFFS:
164 		case NLMSGERR_ATTR_MISS_TYPE:
165 		case NLMSGERR_ATTR_MISS_NEST:
166 			if (len != sizeof(__u32))
167 				return MNL_CB_ERROR;
168 			break;
169 		case NLMSGERR_ATTR_MSG:
170 			str = mnl_attr_get_payload(attr);
171 			if (str[len - 1])
172 				return MNL_CB_ERROR;
173 			break;
174 		default:
175 			break;
176 		}
177 	}
178 
179 	bad_attr[0] = '\0';
180 	miss_attr[0] = '\0';
181 
182 	if (tb[NLMSGERR_ATTR_OFFS]) {
183 		unsigned int n, off;
184 		void *start, *end;
185 
186 		ys->err.attr_offs = mnl_attr_get_u32(tb[NLMSGERR_ATTR_OFFS]);
187 
188 		n = snprintf(bad_attr, sizeof(bad_attr), "%sbad attribute: ",
189 			     str ? " (" : "");
190 
191 		start = mnl_nlmsg_get_payload_offset(ys->nlh,
192 						     sizeof(struct genlmsghdr));
193 		end = mnl_nlmsg_get_payload_tail(ys->nlh);
194 
195 		off = ys->err.attr_offs;
196 		off -= sizeof(struct nlmsghdr);
197 		off -= sizeof(struct genlmsghdr);
198 
199 		n += ynl_err_walk(ys, start, end, off, ys->req_policy,
200 				  &bad_attr[n], sizeof(bad_attr) - n, NULL);
201 
202 		if (n >= sizeof(bad_attr))
203 			n = sizeof(bad_attr) - 1;
204 		bad_attr[n] = '\0';
205 	}
206 	if (tb[NLMSGERR_ATTR_MISS_TYPE]) {
207 		struct ynl_policy_nest *nest_pol = NULL;
208 		unsigned int n, off, type;
209 		void *start, *end;
210 		int n2;
211 
212 		type = mnl_attr_get_u32(tb[NLMSGERR_ATTR_MISS_TYPE]);
213 
214 		n = snprintf(miss_attr, sizeof(miss_attr), "%smissing attribute: ",
215 			     bad_attr[0] ? ", " : (str ? " (" : ""));
216 
217 		start = mnl_nlmsg_get_payload_offset(ys->nlh,
218 						     sizeof(struct genlmsghdr));
219 		end = mnl_nlmsg_get_payload_tail(ys->nlh);
220 
221 		nest_pol = ys->req_policy;
222 		if (tb[NLMSGERR_ATTR_MISS_NEST]) {
223 			off = mnl_attr_get_u32(tb[NLMSGERR_ATTR_MISS_NEST]);
224 			off -= sizeof(struct nlmsghdr);
225 			off -= sizeof(struct genlmsghdr);
226 
227 			n += ynl_err_walk(ys, start, end, off, ys->req_policy,
228 					  &miss_attr[n], sizeof(miss_attr) - n,
229 					  &nest_pol);
230 		}
231 
232 		n2 = 0;
233 		ynl_err_walk_report_one(nest_pol, type, &miss_attr[n],
234 					sizeof(miss_attr) - n, &n2);
235 		n += n2;
236 
237 		if (n >= sizeof(miss_attr))
238 			n = sizeof(miss_attr) - 1;
239 		miss_attr[n] = '\0';
240 	}
241 
242 	/* Implicitly depend on ys->err.code already set */
243 	if (str)
244 		yerr_msg(ys, "Kernel %s: '%s'%s%s%s",
245 			 ys->err.code ? "error" : "warning",
246 			 str, bad_attr, miss_attr,
247 			 bad_attr[0] || miss_attr[0] ? ")" : "");
248 	else if (bad_attr[0] || miss_attr[0])
249 		yerr_msg(ys, "Kernel %s: %s%s",
250 			 ys->err.code ? "error" : "warning",
251 			 bad_attr, miss_attr);
252 
253 	return MNL_CB_OK;
254 }
255 
256 static int ynl_cb_error(const struct nlmsghdr *nlh, void *data)
257 {
258 	const struct nlmsgerr *err = mnl_nlmsg_get_payload(nlh);
259 	struct ynl_parse_arg *yarg = data;
260 	unsigned int hlen;
261 	int code;
262 
263 	code = err->error >= 0 ? err->error : -err->error;
264 	yarg->ys->err.code = code;
265 	errno = code;
266 
267 	hlen = sizeof(*err);
268 	if (!(nlh->nlmsg_flags & NLM_F_CAPPED))
269 		hlen += mnl_nlmsg_get_payload_len(&err->msg);
270 
271 	ynl_ext_ack_check(yarg->ys, nlh, hlen);
272 
273 	return code ? MNL_CB_ERROR : MNL_CB_STOP;
274 }
275 
276 static int ynl_cb_done(const struct nlmsghdr *nlh, void *data)
277 {
278 	struct ynl_parse_arg *yarg = data;
279 	int err;
280 
281 	err = *(int *)NLMSG_DATA(nlh);
282 	if (err < 0) {
283 		yarg->ys->err.code = -err;
284 		errno = -err;
285 
286 		ynl_ext_ack_check(yarg->ys, nlh, sizeof(int));
287 
288 		return MNL_CB_ERROR;
289 	}
290 	return MNL_CB_STOP;
291 }
292 
293 static int ynl_cb_noop(const struct nlmsghdr *nlh, void *data)
294 {
295 	return MNL_CB_OK;
296 }
297 
298 mnl_cb_t ynl_cb_array[NLMSG_MIN_TYPE] = {
299 	[NLMSG_NOOP]	= ynl_cb_noop,
300 	[NLMSG_ERROR]	= ynl_cb_error,
301 	[NLMSG_DONE]	= ynl_cb_done,
302 	[NLMSG_OVERRUN]	= ynl_cb_noop,
303 };
304 
305 /* Attribute validation */
306 
307 int ynl_attr_validate(struct ynl_parse_arg *yarg, const struct nlattr *attr)
308 {
309 	struct ynl_policy_attr *policy;
310 	unsigned int type, len;
311 	unsigned char *data;
312 
313 	data = mnl_attr_get_payload(attr);
314 	len = mnl_attr_get_payload_len(attr);
315 	type = mnl_attr_get_type(attr);
316 	if (type > yarg->rsp_policy->max_attr) {
317 		yerr(yarg->ys, YNL_ERROR_INTERNAL,
318 		     "Internal error, validating unknown attribute");
319 		return -1;
320 	}
321 
322 	policy = &yarg->rsp_policy->table[type];
323 
324 	switch (policy->type) {
325 	case YNL_PT_REJECT:
326 		yerr(yarg->ys, YNL_ERROR_ATTR_INVALID,
327 		     "Rejected attribute (%s)", policy->name);
328 		return -1;
329 	case YNL_PT_IGNORE:
330 		break;
331 	case YNL_PT_U8:
332 		if (len == sizeof(__u8))
333 			break;
334 		yerr(yarg->ys, YNL_ERROR_ATTR_INVALID,
335 		     "Invalid attribute (u8 %s)", policy->name);
336 		return -1;
337 	case YNL_PT_U16:
338 		if (len == sizeof(__u16))
339 			break;
340 		yerr(yarg->ys, YNL_ERROR_ATTR_INVALID,
341 		     "Invalid attribute (u16 %s)", policy->name);
342 		return -1;
343 	case YNL_PT_U32:
344 		if (len == sizeof(__u32))
345 			break;
346 		yerr(yarg->ys, YNL_ERROR_ATTR_INVALID,
347 		     "Invalid attribute (u32 %s)", policy->name);
348 		return -1;
349 	case YNL_PT_U64:
350 		if (len == sizeof(__u64))
351 			break;
352 		yerr(yarg->ys, YNL_ERROR_ATTR_INVALID,
353 		     "Invalid attribute (u64 %s)", policy->name);
354 		return -1;
355 	case YNL_PT_UINT:
356 		if (len == sizeof(__u32) || len == sizeof(__u64))
357 			break;
358 		yerr(yarg->ys, YNL_ERROR_ATTR_INVALID,
359 		     "Invalid attribute (uint %s)", policy->name);
360 		return -1;
361 	case YNL_PT_FLAG:
362 		/* Let flags grow into real attrs, why not.. */
363 		break;
364 	case YNL_PT_NEST:
365 		if (!len || len >= sizeof(*attr))
366 			break;
367 		yerr(yarg->ys, YNL_ERROR_ATTR_INVALID,
368 		     "Invalid attribute (nest %s)", policy->name);
369 		return -1;
370 	case YNL_PT_BINARY:
371 		if (!policy->len || len == policy->len)
372 			break;
373 		yerr(yarg->ys, YNL_ERROR_ATTR_INVALID,
374 		     "Invalid attribute (binary %s)", policy->name);
375 		return -1;
376 	case YNL_PT_NUL_STR:
377 		if ((!policy->len || len <= policy->len) && !data[len - 1])
378 			break;
379 		yerr(yarg->ys, YNL_ERROR_ATTR_INVALID,
380 		     "Invalid attribute (string %s)", policy->name);
381 		return -1;
382 	case YNL_PT_BITFIELD32:
383 		if (len == sizeof(struct nla_bitfield32))
384 			break;
385 		yerr(yarg->ys, YNL_ERROR_ATTR_INVALID,
386 		     "Invalid attribute (bitfield32 %s)", policy->name);
387 		return -1;
388 	default:
389 		yerr(yarg->ys, YNL_ERROR_ATTR_INVALID,
390 		     "Invalid attribute (unknown %s)", policy->name);
391 		return -1;
392 	}
393 
394 	return 0;
395 }
396 
397 /* Generic code */
398 
399 static void ynl_err_reset(struct ynl_sock *ys)
400 {
401 	ys->err.code = 0;
402 	ys->err.attr_offs = 0;
403 	ys->err.msg[0] = 0;
404 }
405 
406 struct nlmsghdr *ynl_msg_start(struct ynl_sock *ys, __u32 id, __u16 flags)
407 {
408 	struct nlmsghdr *nlh;
409 
410 	ynl_err_reset(ys);
411 
412 	nlh = ys->nlh = mnl_nlmsg_put_header(ys->tx_buf);
413 	nlh->nlmsg_type	= id;
414 	nlh->nlmsg_flags = flags;
415 	nlh->nlmsg_seq = ++ys->seq;
416 
417 	return nlh;
418 }
419 
420 struct nlmsghdr *
421 ynl_gemsg_start(struct ynl_sock *ys, __u32 id, __u16 flags,
422 		__u8 cmd, __u8 version)
423 {
424 	struct genlmsghdr gehdr;
425 	struct nlmsghdr *nlh;
426 	void *data;
427 
428 	nlh = ynl_msg_start(ys, id, flags);
429 
430 	memset(&gehdr, 0, sizeof(gehdr));
431 	gehdr.cmd = cmd;
432 	gehdr.version = version;
433 
434 	data = mnl_nlmsg_put_extra_header(nlh, sizeof(gehdr));
435 	memcpy(data, &gehdr, sizeof(gehdr));
436 
437 	return nlh;
438 }
439 
440 void ynl_msg_start_req(struct ynl_sock *ys, __u32 id)
441 {
442 	ynl_msg_start(ys, id, NLM_F_REQUEST | NLM_F_ACK);
443 }
444 
445 void ynl_msg_start_dump(struct ynl_sock *ys, __u32 id)
446 {
447 	ynl_msg_start(ys, id, NLM_F_REQUEST | NLM_F_ACK | NLM_F_DUMP);
448 }
449 
450 struct nlmsghdr *
451 ynl_gemsg_start_req(struct ynl_sock *ys, __u32 id, __u8 cmd, __u8 version)
452 {
453 	return ynl_gemsg_start(ys, id, NLM_F_REQUEST | NLM_F_ACK, cmd, version);
454 }
455 
456 struct nlmsghdr *
457 ynl_gemsg_start_dump(struct ynl_sock *ys, __u32 id, __u8 cmd, __u8 version)
458 {
459 	return ynl_gemsg_start(ys, id, NLM_F_REQUEST | NLM_F_ACK | NLM_F_DUMP,
460 			       cmd, version);
461 }
462 
463 int ynl_recv_ack(struct ynl_sock *ys, int ret)
464 {
465 	if (!ret) {
466 		yerr(ys, YNL_ERROR_EXPECT_ACK,
467 		     "Expecting an ACK but nothing received");
468 		return -1;
469 	}
470 
471 	ret = mnl_socket_recvfrom(ys->sock, ys->rx_buf, MNL_SOCKET_BUFFER_SIZE);
472 	if (ret < 0) {
473 		perr(ys, "Socket receive failed");
474 		return ret;
475 	}
476 	return mnl_cb_run(ys->rx_buf, ret, ys->seq, ys->portid,
477 			  ynl_cb_null, ys);
478 }
479 
480 int ynl_cb_null(const struct nlmsghdr *nlh, void *data)
481 {
482 	struct ynl_parse_arg *yarg = data;
483 
484 	yerr(yarg->ys, YNL_ERROR_UNEXPECT_MSG,
485 	     "Received a message when none were expected");
486 
487 	return MNL_CB_ERROR;
488 }
489 
490 /* Init/fini and genetlink boiler plate */
491 static int
492 ynl_get_family_info_mcast(struct ynl_sock *ys, const struct nlattr *mcasts)
493 {
494 	const struct nlattr *entry, *attr;
495 	unsigned int i;
496 
497 	mnl_attr_for_each_nested(attr, mcasts)
498 		ys->n_mcast_groups++;
499 
500 	if (!ys->n_mcast_groups)
501 		return 0;
502 
503 	ys->mcast_groups = calloc(ys->n_mcast_groups,
504 				  sizeof(*ys->mcast_groups));
505 	if (!ys->mcast_groups)
506 		return MNL_CB_ERROR;
507 
508 	i = 0;
509 	mnl_attr_for_each_nested(entry, mcasts) {
510 		mnl_attr_for_each_nested(attr, entry) {
511 			if (mnl_attr_get_type(attr) == CTRL_ATTR_MCAST_GRP_ID)
512 				ys->mcast_groups[i].id = mnl_attr_get_u32(attr);
513 			if (mnl_attr_get_type(attr) == CTRL_ATTR_MCAST_GRP_NAME) {
514 				strncpy(ys->mcast_groups[i].name,
515 					mnl_attr_get_str(attr),
516 					GENL_NAMSIZ - 1);
517 				ys->mcast_groups[i].name[GENL_NAMSIZ - 1] = 0;
518 			}
519 		}
520 	}
521 
522 	return 0;
523 }
524 
525 static int ynl_get_family_info_cb(const struct nlmsghdr *nlh, void *data)
526 {
527 	struct ynl_parse_arg *yarg = data;
528 	struct ynl_sock *ys = yarg->ys;
529 	const struct nlattr *attr;
530 	bool found_id = true;
531 
532 	mnl_attr_for_each(attr, nlh, sizeof(struct genlmsghdr)) {
533 		if (mnl_attr_get_type(attr) == CTRL_ATTR_MCAST_GROUPS)
534 			if (ynl_get_family_info_mcast(ys, attr))
535 				return MNL_CB_ERROR;
536 
537 		if (mnl_attr_get_type(attr) != CTRL_ATTR_FAMILY_ID)
538 			continue;
539 
540 		if (mnl_attr_get_payload_len(attr) != sizeof(__u16)) {
541 			yerr(ys, YNL_ERROR_ATTR_INVALID, "Invalid family ID");
542 			return MNL_CB_ERROR;
543 		}
544 
545 		ys->family_id = mnl_attr_get_u16(attr);
546 		found_id = true;
547 	}
548 
549 	if (!found_id) {
550 		yerr(ys, YNL_ERROR_ATTR_MISSING, "Family ID missing");
551 		return MNL_CB_ERROR;
552 	}
553 	return MNL_CB_OK;
554 }
555 
556 static int ynl_sock_read_family(struct ynl_sock *ys, const char *family_name)
557 {
558 	struct ynl_parse_arg yarg = { .ys = ys, };
559 	struct nlmsghdr *nlh;
560 	int err;
561 
562 	nlh = ynl_gemsg_start_req(ys, GENL_ID_CTRL, CTRL_CMD_GETFAMILY, 1);
563 	mnl_attr_put_strz(nlh, CTRL_ATTR_FAMILY_NAME, family_name);
564 
565 	err = mnl_socket_sendto(ys->sock, nlh, nlh->nlmsg_len);
566 	if (err < 0) {
567 		perr(ys, "failed to request socket family info");
568 		return err;
569 	}
570 
571 	err = mnl_socket_recvfrom(ys->sock, ys->rx_buf, MNL_SOCKET_BUFFER_SIZE);
572 	if (err <= 0) {
573 		perr(ys, "failed to receive the socket family info");
574 		return err;
575 	}
576 	err = mnl_cb_run2(ys->rx_buf, err, ys->seq, ys->portid,
577 			  ynl_get_family_info_cb, &yarg,
578 			  ynl_cb_array, ARRAY_SIZE(ynl_cb_array));
579 	if (err < 0) {
580 		free(ys->mcast_groups);
581 		perr(ys, "failed to receive the socket family info - no such family?");
582 		return err;
583 	}
584 
585 	return ynl_recv_ack(ys, err);
586 }
587 
588 struct ynl_sock *
589 ynl_sock_create(const struct ynl_family *yf, struct ynl_error *yse)
590 {
591 	struct ynl_sock *ys;
592 	int one = 1;
593 
594 	ys = malloc(sizeof(*ys) + 2 * MNL_SOCKET_BUFFER_SIZE);
595 	if (!ys)
596 		return NULL;
597 	memset(ys, 0, sizeof(*ys));
598 
599 	ys->family = yf;
600 	ys->tx_buf = &ys->raw_buf[0];
601 	ys->rx_buf = &ys->raw_buf[MNL_SOCKET_BUFFER_SIZE];
602 	ys->ntf_last_next = &ys->ntf_first;
603 
604 	ys->sock = mnl_socket_open(NETLINK_GENERIC);
605 	if (!ys->sock) {
606 		__perr(yse, "failed to create a netlink socket");
607 		goto err_free_sock;
608 	}
609 
610 	if (mnl_socket_setsockopt(ys->sock, NETLINK_CAP_ACK,
611 				  &one, sizeof(one))) {
612 		__perr(yse, "failed to enable netlink ACK");
613 		goto err_close_sock;
614 	}
615 	if (mnl_socket_setsockopt(ys->sock, NETLINK_EXT_ACK,
616 				  &one, sizeof(one))) {
617 		__perr(yse, "failed to enable netlink ext ACK");
618 		goto err_close_sock;
619 	}
620 
621 	ys->seq = random();
622 	ys->portid = mnl_socket_get_portid(ys->sock);
623 
624 	if (ynl_sock_read_family(ys, yf->name)) {
625 		if (yse)
626 			memcpy(yse, &ys->err, sizeof(*yse));
627 		goto err_close_sock;
628 	}
629 
630 	return ys;
631 
632 err_close_sock:
633 	mnl_socket_close(ys->sock);
634 err_free_sock:
635 	free(ys);
636 	return NULL;
637 }
638 
639 void ynl_sock_destroy(struct ynl_sock *ys)
640 {
641 	struct ynl_ntf_base_type *ntf;
642 
643 	mnl_socket_close(ys->sock);
644 	while ((ntf = ynl_ntf_dequeue(ys)))
645 		ynl_ntf_free(ntf);
646 	free(ys->mcast_groups);
647 	free(ys);
648 }
649 
650 /* YNL multicast handling */
651 
652 void ynl_ntf_free(struct ynl_ntf_base_type *ntf)
653 {
654 	ntf->free(ntf);
655 }
656 
657 int ynl_subscribe(struct ynl_sock *ys, const char *grp_name)
658 {
659 	unsigned int i;
660 	int err;
661 
662 	for (i = 0; i < ys->n_mcast_groups; i++)
663 		if (!strcmp(ys->mcast_groups[i].name, grp_name))
664 			break;
665 	if (i == ys->n_mcast_groups) {
666 		yerr(ys, ENOENT, "Multicast group '%s' not found", grp_name);
667 		return -1;
668 	}
669 
670 	err = mnl_socket_setsockopt(ys->sock, NETLINK_ADD_MEMBERSHIP,
671 				    &ys->mcast_groups[i].id,
672 				    sizeof(ys->mcast_groups[i].id));
673 	if (err < 0) {
674 		perr(ys, "Subscribing to multicast group failed");
675 		return -1;
676 	}
677 
678 	return 0;
679 }
680 
681 int ynl_socket_get_fd(struct ynl_sock *ys)
682 {
683 	return mnl_socket_get_fd(ys->sock);
684 }
685 
686 struct ynl_ntf_base_type *ynl_ntf_dequeue(struct ynl_sock *ys)
687 {
688 	struct ynl_ntf_base_type *ntf;
689 
690 	if (!ynl_has_ntf(ys))
691 		return NULL;
692 
693 	ntf = ys->ntf_first;
694 	ys->ntf_first = ntf->next;
695 	if (ys->ntf_last_next == &ntf->next)
696 		ys->ntf_last_next = &ys->ntf_first;
697 
698 	return ntf;
699 }
700 
701 static int ynl_ntf_parse(struct ynl_sock *ys, const struct nlmsghdr *nlh)
702 {
703 	struct ynl_parse_arg yarg = { .ys = ys, };
704 	const struct ynl_ntf_info *info;
705 	struct ynl_ntf_base_type *rsp;
706 	struct genlmsghdr *gehdr;
707 	int ret;
708 
709 	gehdr = mnl_nlmsg_get_payload(nlh);
710 	if (gehdr->cmd >= ys->family->ntf_info_size)
711 		return MNL_CB_ERROR;
712 	info = &ys->family->ntf_info[gehdr->cmd];
713 	if (!info->cb)
714 		return MNL_CB_ERROR;
715 
716 	rsp = calloc(1, info->alloc_sz);
717 	rsp->free = info->free;
718 	yarg.data = rsp->data;
719 	yarg.rsp_policy = info->policy;
720 
721 	ret = info->cb(nlh, &yarg);
722 	if (ret <= MNL_CB_STOP)
723 		goto err_free;
724 
725 	rsp->family = nlh->nlmsg_type;
726 	rsp->cmd = gehdr->cmd;
727 
728 	*ys->ntf_last_next = rsp;
729 	ys->ntf_last_next = &rsp->next;
730 
731 	return MNL_CB_OK;
732 
733 err_free:
734 	info->free(rsp);
735 	return MNL_CB_ERROR;
736 }
737 
738 static int ynl_ntf_trampoline(const struct nlmsghdr *nlh, void *data)
739 {
740 	return ynl_ntf_parse((struct ynl_sock *)data, nlh);
741 }
742 
743 int ynl_ntf_check(struct ynl_sock *ys)
744 {
745 	ssize_t len;
746 	int err;
747 
748 	do {
749 		/* libmnl doesn't let us pass flags to the recv to make
750 		 * it non-blocking so we need to poll() or peek() :|
751 		 */
752 		struct pollfd pfd = { };
753 
754 		pfd.fd = mnl_socket_get_fd(ys->sock);
755 		pfd.events = POLLIN;
756 		err = poll(&pfd, 1, 1);
757 		if (err < 1)
758 			return err;
759 
760 		len = mnl_socket_recvfrom(ys->sock, ys->rx_buf,
761 					  MNL_SOCKET_BUFFER_SIZE);
762 		if (len < 0)
763 			return len;
764 
765 		err = mnl_cb_run2(ys->rx_buf, len, ys->seq, ys->portid,
766 				  ynl_ntf_trampoline, ys,
767 				  ynl_cb_array, NLMSG_MIN_TYPE);
768 		if (err < 0)
769 			return err;
770 	} while (err > 0);
771 
772 	return 0;
773 }
774 
775 /* YNL specific helpers used by the auto-generated code */
776 
/*
 * Sentinel stored in the 'next' pointer of the last dump list entry (and
 * returned for an empty list) by ynl_dump_end(). It is an arbitrary
 * constant, not a dereferenceable address.
 */
struct ynl_dump_list_type *YNL_LIST_END = (void *)0xb4d123;
778 
779 void ynl_error_unknown_notification(struct ynl_sock *ys, __u8 cmd)
780 {
781 	yerr(ys, YNL_ERROR_UNKNOWN_NTF,
782 	     "Unknown notification message type '%d'", cmd);
783 }
784 
785 int ynl_error_parse(struct ynl_parse_arg *yarg, const char *msg)
786 {
787 	yerr(yarg->ys, YNL_ERROR_INV_RESP, "Error parsing response: %s", msg);
788 	return MNL_CB_ERROR;
789 }
790 
791 static int
792 ynl_check_alien(struct ynl_sock *ys, const struct nlmsghdr *nlh, __u32 rsp_cmd)
793 {
794 	struct genlmsghdr *gehdr;
795 
796 	if (mnl_nlmsg_get_payload_len(nlh) < sizeof(*gehdr)) {
797 		yerr(ys, YNL_ERROR_INV_RESP,
798 		     "Kernel responded with truncated message");
799 		return -1;
800 	}
801 
802 	gehdr = mnl_nlmsg_get_payload(nlh);
803 	if (gehdr->cmd != rsp_cmd)
804 		return ynl_ntf_parse(ys, nlh);
805 
806 	return 0;
807 }
808 
809 static int ynl_req_trampoline(const struct nlmsghdr *nlh, void *data)
810 {
811 	struct ynl_req_state *yrs = data;
812 	int ret;
813 
814 	ret = ynl_check_alien(yrs->yarg.ys, nlh, yrs->rsp_cmd);
815 	if (ret)
816 		return ret < 0 ? MNL_CB_ERROR : MNL_CB_OK;
817 
818 	return yrs->cb(nlh, &yrs->yarg);
819 }
820 
821 int ynl_exec(struct ynl_sock *ys, struct nlmsghdr *req_nlh,
822 	     struct ynl_req_state *yrs)
823 {
824 	ssize_t len;
825 	int err;
826 
827 	err = mnl_socket_sendto(ys->sock, req_nlh, req_nlh->nlmsg_len);
828 	if (err < 0)
829 		return err;
830 
831 	do {
832 		len = mnl_socket_recvfrom(ys->sock, ys->rx_buf,
833 					  MNL_SOCKET_BUFFER_SIZE);
834 		if (len < 0)
835 			return len;
836 
837 		err = mnl_cb_run2(ys->rx_buf, len, ys->seq, ys->portid,
838 				  ynl_req_trampoline, yrs,
839 				  ynl_cb_array, NLMSG_MIN_TYPE);
840 		if (err < 0)
841 			return err;
842 	} while (err > 0);
843 
844 	return 0;
845 }
846 
847 static int ynl_dump_trampoline(const struct nlmsghdr *nlh, void *data)
848 {
849 	struct ynl_dump_state *ds = data;
850 	struct ynl_dump_list_type *obj;
851 	struct ynl_parse_arg yarg = {};
852 	int ret;
853 
854 	ret = ynl_check_alien(ds->ys, nlh, ds->rsp_cmd);
855 	if (ret)
856 		return ret < 0 ? MNL_CB_ERROR : MNL_CB_OK;
857 
858 	obj = calloc(1, ds->alloc_sz);
859 	if (!obj)
860 		return MNL_CB_ERROR;
861 
862 	if (!ds->first)
863 		ds->first = obj;
864 	if (ds->last)
865 		ds->last->next = obj;
866 	ds->last = obj;
867 
868 	yarg.ys = ds->ys;
869 	yarg.rsp_policy = ds->rsp_policy;
870 	yarg.data = &obj->data;
871 
872 	return ds->cb(nlh, &yarg);
873 }
874 
875 static void *ynl_dump_end(struct ynl_dump_state *ds)
876 {
877 	if (!ds->first)
878 		return YNL_LIST_END;
879 
880 	ds->last->next = YNL_LIST_END;
881 	return ds->first;
882 }
883 
884 int ynl_exec_dump(struct ynl_sock *ys, struct nlmsghdr *req_nlh,
885 		  struct ynl_dump_state *yds)
886 {
887 	ssize_t len;
888 	int err;
889 
890 	err = mnl_socket_sendto(ys->sock, req_nlh, req_nlh->nlmsg_len);
891 	if (err < 0)
892 		return err;
893 
894 	do {
895 		len = mnl_socket_recvfrom(ys->sock, ys->rx_buf,
896 					  MNL_SOCKET_BUFFER_SIZE);
897 		if (len < 0)
898 			goto err_close_list;
899 
900 		err = mnl_cb_run2(ys->rx_buf, len, ys->seq, ys->portid,
901 				  ynl_dump_trampoline, yds,
902 				  ynl_cb_array, NLMSG_MIN_TYPE);
903 		if (err < 0)
904 			goto err_close_list;
905 	} while (err > 0);
906 
907 	yds->first = ynl_dump_end(yds);
908 	return 0;
909 
910 err_close_list:
911 	yds->first = ynl_dump_end(yds);
912 	return -1;
913 }
914