xref: /freebsd/sys/rpc/clnt_nl.c (revision 936f1765b2b349809591afb8209b022c9af61bd7)
1 /*-
2  * SPDX-License-Identifier: BSD-2-Clause
3  *
4  * Copyright (c) 2025 Gleb Smirnoff <glebius@FreeBSD.org>
5  *
6  * Redistribution and use in source and binary forms, with or without
7  * modification, are permitted provided that the following conditions
8  * are met:
9  * 1. Redistributions of source code must retain the above copyright
10  *    notice, this list of conditions and the following disclaimer.
11  * 2. Redistributions in binary form must reproduce the above copyright
12  *    notice, this list of conditions and the following disclaimer in the
13  *    documentation and/or other materials provided with the distribution.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
16  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
18  * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
19  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
20  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
21  * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
22  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
23  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
24  * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
25  * SUCH DAMAGE.
26  */
27 
28 #include <sys/param.h>
29 #include <sys/lock.h>
30 #include <sys/kernel.h>
31 #include <sys/malloc.h>
32 #include <sys/mutex.h>
33 #include <sys/rwlock.h>
34 #include <sys/mbuf.h>
35 #include <sys/priv.h>
36 #include <sys/proc.h>
37 #include <sys/queue.h>
38 #include <sys/tree.h>
39 
40 #include <rpc/rpc.h>
41 #include <rpc/rpc_com.h>
42 #include <rpc/krpc.h>
43 #include <rpc/clnt_nl.h>
44 
45 #include <netlink/netlink.h>
46 #include <netlink/netlink_ctl.h>
47 #include <netlink/netlink_generic.h>
48 
49 /*
50  * Kernel RPC client over netlink(4), where kernel is RPC client and an
51  * application is a server.  See svc_nl.c in the libc/rpc as the counterpart.
52  *
53  * The module registers itself within generic netlink families list under name
54  * "rpc".  Every new client creates a new multicast group belonging to this
55  * family.  When a client starts RPC, the module will multicast the call to
56  * potential netlink listeners and sleep/retry until receiving a result.  The
57  * framing of the request:
58  *
59  * [netlink message header, type = "rpc" ID, seq == xid]
60  * [generic netlink header, cmd = RPCNL_REQUEST]
61  * [netlink attribute RPCNL_REQUEST_GROUP]
62  * [group ID]
63  * [netlink attribute RPCNL_REQUEST_BODY]
64  * [XDR encoded payload]
65  *
66  * Note: the generic netlink header and attributes aren't really necessary
67  * for successful communication, since the netlink multicast membership already
68  * guarantees us all needed filtering.  The working prototype was putting the
69  * XDR encoded payload right after netlink message header.  But we will provide
70  * this framing to allow for any future extensions.
71  *
72  * The expected RPC result from the userland shall be framed like this:
73  *
74  * [netlink message header, type = "rpc" ID, seq == xid]
75  * [generic netlink header, cmd = RPCNL_REPLY]
76  * [netlink attribute RPCNL_REPLY_GROUP]
77  * [group ID]
78  * [netlink attribute RPCNL_REPLY_BODY]
79  * [XDR encoded payload]
80  *
81  * Disclaimer: has been designed and tested only for the NFS related kernel
82  * RPC clients: kgssapi, RPC binding for NLM, TLS client and TLS server.
83  *
84  * Caveats:
85  * 1) Now the privilege checking is hardcoded to PRIV_NFS_DAEMON at the netlink
86  *    command and multicast layers.  If any new client in addition to NFS
87  *    service emerges, we may want to rewrite privilege checking at the client
88  *    level somehow.
89  * 2) Since we are using netlink attribute for the payload, payload size is
90  *    limited to UINT16_MAX.  Today RPC_MAXDATASIZE of 9000 fits well within
91  *    that limit.  What if a future RPC wants more?
92  */
93 
94 static enum clnt_stat clnt_nl_call(CLIENT *, struct rpc_callextra *,
95     rpcproc_t, struct mbuf *, struct mbuf **, struct timeval);
96 static void clnt_nl_close(CLIENT *);
97 static void clnt_nl_destroy(CLIENT *);
98 static bool_t clnt_nl_control(CLIENT *, u_int, void *);
99 
/*
 * Method table for netlink RPC clients.  client_nl_create() stores a pointer
 * to it in the cl_ops member of every CLIENT it returns.
 */
100 static const struct clnt_ops clnt_nl_ops = {
101 	.cl_call =	clnt_nl_call,
102 	.cl_close =	clnt_nl_close,
103 	.cl_destroy =	clnt_nl_destroy,
104 	.cl_control =	clnt_nl_control,
105 };
106 
107 static int clnt_nl_reply(struct nlmsghdr *, struct nl_pstate *);
108 
109 static const struct genl_cmd clnt_cmds[] = {
110 	{
111 		.cmd_num = RPCNL_REPLY,
112 		.cmd_name = "request",
113 		.cmd_cb = clnt_nl_reply,
114 		.cmd_priv = PRIV_NFS_DAEMON,
115 	},
116 };
117 
/*
 * Attributes extracted from a reply message by rpcnl_parser:
 * group - value of the RPCNL_REPLY_GROUP attribute, used by clnt_nl_reply()
 *         to look up the client in rpcnl_clients;
 * data  - pointer to the RPCNL_REPLY_BODY attribute; stays NULL when the
 *         attribute is absent, which clnt_nl_reply() rejects with EINVAL.
 */
118 struct nl_reply_parsed {
119 	uint32_t	group;
120 	struct nlattr	*data;
121 };
/* Maps each reply attribute type onto a field of struct nl_reply_parsed. */
122 static const struct nlattr_parser rpcnl_attr_parser[] = {
123 #define	OUT(field)	offsetof(struct nl_reply_parsed, field)
124     { .type = RPCNL_REPLY_GROUP, .off = OUT(group), .cb = nlattr_get_uint32 },
125     { .type = RPCNL_REPLY_BODY, .off = OUT(data), .cb = nlattr_get_nla },
126 #undef OUT
127 };
128 NL_DECLARE_PARSER(rpcnl_parser, struct genlmsghdr, nlf_p_empty,
129     rpcnl_attr_parser);
130 
/*
 * Per-client private state, reachable through CLIENT's cl_private.
 *
 * nl_lock     - mutex held while walking or modifying nl_pending, while
 *               sleeping for a reply, and in clnt_nl_control().
 * nl_tree     - linkage in the rpcnl_clients tree (keyed by nl_hdr.group).
 * nl_pending  - requests that have been sent and await a reply.
 * nl_xid      - xid source; clnt_nl_call() takes the next value with
 *               atomic_fetchadd_32().
 * nl_mpos     - number of bytes of pre-serialized call header in nl_mcallc.
 * nl_authlen  - size measured at creation time of the procedure number plus
 *               the authnone header; used to size the request buffer.
 * nl_retries  - upper bound on send attempts; also divides the sleep timeout.
 * nl_hdr      - copied verbatim into every request right after the netlink
 *               message header.
 * nl_timo     - sleep timeout in ticks; when zero clnt_nl_call() falls back
 *               to the timeout supplied by its caller.
 */
131 struct nl_data {
132 	struct mtx	nl_lock;
133 	RB_ENTRY(nl_data) nl_tree;
134 	TAILQ_HEAD(, ct_request) nl_pending;
135 	uint32_t	nl_xid;
136 	u_int		nl_mpos;
137 	u_int		nl_authlen;
138 	u_int		nl_retries;
139 	struct {
140 		struct genlmsghdr ghdr;
141 		struct nlattr gattr;
142 		uint32_t group;
143 	} nl_hdr;	/* pre-initialized header */
144 	char		nl_mcallc[MCALL_MSG_SIZE]; /* marshalled callmsg */
145 	/* msleep(9) arguments */
146 	const char *	nl_wchan;
147 	int		nl_prio;
148 	int		nl_timo;
149 };
150 
151 static RB_HEAD(nl_data_t, nl_data) rpcnl_clients;
152 static int32_t
nl_data_compare(const struct nl_data * a,const struct nl_data * b)153 nl_data_compare(const struct nl_data *a, const struct nl_data *b)
154 {
155 	return ((int32_t)(a->nl_hdr.group - b->nl_hdr.group));
156 }
157 RB_GENERATE_STATIC(nl_data_t, nl_data, nl_tree, nl_data_compare);
158 static struct rwlock rpcnl_global_lock;
159 
160 static const char rpcnl_family_name[] = "rpc";
161 static uint16_t rpcnl_family_id;
162 
163 void
rpcnl_init(void)164 rpcnl_init(void)
165 {
166 	bool rv __diagused;
167 
168 	rpcnl_family_id = genl_register_family(rpcnl_family_name, 0, 1, 1);
169 	MPASS(rpcnl_family_id != 0);
170 	rv = genl_register_cmds(rpcnl_family_id, clnt_cmds, nitems(clnt_cmds));
171 	MPASS(rv);
172 	rw_init(&rpcnl_global_lock, rpcnl_family_name);
173 }
174 
175 CLIENT *
client_nl_create(const char * name,const rpcprog_t program,const rpcvers_t version)176 client_nl_create(const char *name, const rpcprog_t program,
177     const rpcvers_t version)
178 {
179 	CLIENT *cl;
180 	struct nl_data *nl;
181 	struct timeval now;
182 	struct rpc_msg call_msg;
183 	XDR xdrs;
184 	uint32_t group;
185 	bool rv __diagused;
186 
187 	if ((group = genl_register_group(rpcnl_family_id, name)) == 0)
188 		return (NULL);
189 
190 	nl = malloc(sizeof(*nl), M_RPC, M_WAITOK);
191 	*nl = (struct nl_data){
192 		.nl_pending = TAILQ_HEAD_INITIALIZER(nl->nl_pending),
193 		.nl_hdr = {
194 			.ghdr.cmd = RPCNL_REQUEST,
195 			.gattr.nla_type = RPCNL_REQUEST_GROUP,
196 			.gattr.nla_len = sizeof(struct nlattr) +
197 			    sizeof(uint32_t),
198 			.group = group,
199 		},
200 		.nl_wchan = rpcnl_family_name,
201 		.nl_prio = PSOCK | PCATCH,
202 		.nl_timo = 60 * hz,
203 		.nl_retries = 1,
204 	};
205 	mtx_init(&nl->nl_lock, "rpc_clnt_nl", NULL, MTX_DEF);
206 
207 	/*
208 	 * Initialize and pre-serialize the static part of the call message.
209 	 */
210 	getmicrotime(&now);
211 	nl->nl_xid = __RPC_GETXID(&now);
212 	call_msg = (struct rpc_msg ){
213 		.rm_xid = nl->nl_xid,
214 		.rm_direction = CALL,
215 		.rm_call = {
216 			.cb_rpcvers = RPC_MSG_VERSION,
217 			.cb_prog = (uint32_t)program,
218 			.cb_vers = (uint32_t)version,
219 		},
220 	};
221 
222 	cl = malloc(sizeof(*cl), M_RPC, M_WAITOK);
223 	*cl = (CLIENT){
224 		.cl_refs = 1,
225 		.cl_ops = &clnt_nl_ops,
226 		.cl_private = nl,
227 		.cl_auth = authnone_create(),
228 	};
229 
230 	/*
231 	 * Experimentally learn how many bytes does procedure name plus
232 	 * authnone header needs.  Use nl_mcallc as temporary scratch space.
233 	 */
234 	xdrmem_create(&xdrs, nl->nl_mcallc, MCALL_MSG_SIZE, XDR_ENCODE);
235 	rv = xdr_putint32(&xdrs, &(rpcproc_t){0});
236 	MPASS(rv);
237 	rv = AUTH_MARSHALL(cl->cl_auth, 0, &xdrs, NULL);
238 	MPASS(rv);
239 	nl->nl_authlen = xdr_getpos(&xdrs);
240 	xdr_destroy(&xdrs);
241 
242 	xdrmem_create(&xdrs, nl->nl_mcallc, MCALL_MSG_SIZE, XDR_ENCODE);
243 	rv = xdr_callhdr(&xdrs, &call_msg);
244 	MPASS(rv);
245 	nl->nl_mpos = xdr_getpos(&xdrs);
246 	xdr_destroy(&xdrs);
247 
248 	rw_wlock(&rpcnl_global_lock);
249 	RB_INSERT(nl_data_t, &rpcnl_clients, nl);
250 	rw_wunlock(&rpcnl_global_lock);
251 
252 	return (cl);
253 }
254 
255 static enum clnt_stat
clnt_nl_call(CLIENT * cl,struct rpc_callextra * ext,rpcproc_t proc,struct mbuf * args,struct mbuf ** resultsp,struct timeval utimeout)256 clnt_nl_call(CLIENT *cl, struct rpc_callextra *ext, rpcproc_t proc,
257     struct mbuf *args, struct mbuf **resultsp, struct timeval utimeout)
258 {
259 	struct nl_writer nw;
260 	struct nl_data *nl = cl->cl_private;
261 	struct ct_request *cr;
262 	struct rpc_err *errp, err;
263 	enum clnt_stat stat;
264 	AUTH *auth;
265 	XDR xdrs;
266 	void *mem;
267 	uint32_t len, xlen;
268 	u_int retries = 0;
269 	bool rv __diagused;
270 
271 	CURVNET_ASSERT_SET();
272 
273 	cr = malloc(sizeof(struct ct_request), M_RPC, M_WAITOK);
274 	*cr = (struct ct_request){
275 		.cr_xid = atomic_fetchadd_32(&nl->nl_xid, 1),
276 		.cr_error = ETIMEDOUT,
277 #ifdef VIMAGE
278 		.cr_vnet = curvnet,
279 #endif
280 	};
281 
282 	if (ext) {
283 		auth = ext->rc_auth;
284 		errp = &ext->rc_err;
285 		len = RPC_MAXDATASIZE;	/* XXXGL: can be improved */
286 	} else {
287 		auth = cl->cl_auth;
288 		errp = &err;
289 		len = nl->nl_mpos + nl->nl_authlen + m_length(args, NULL);
290 	}
291 
292 	mem = malloc(len, M_RPC, M_WAITOK);
293 retry:
294 	xdrmem_create(&xdrs, mem, len, XDR_ENCODE);
295 
296 	rv = xdr_putbytes(&xdrs, nl->nl_mcallc, nl->nl_mpos);
297 	MPASS(rv);
298 	rv = xdr_putint32(&xdrs, &proc);
299 	MPASS(rv);
300 	if (!AUTH_MARSHALL(auth, cr->cr_xid, &xdrs, args)) {
301 		stat = errp->re_status = RPC_CANTENCODEARGS;
302 		goto out;
303 	} else
304 		stat = errp->re_status = RPC_SUCCESS;
305 
306 	/* XXX: XID is the first thing in the request. */
307 	*(uint32_t *)mem = htonl(cr->cr_xid);
308 
309 	xlen = xdr_getpos(&xdrs);
310 	rv = nl_writer_group(&nw, xlen, NETLINK_GENERIC, nl->nl_hdr.group,
311 	    PRIV_NFS_DAEMON, true);
312 	MPASS(rv);
313 
314 	rv = nlmsg_add(&nw, 0, cr->cr_xid, rpcnl_family_id, 0,
315 	    sizeof(nl->nl_hdr) + sizeof(struct nlattr) + xlen);
316 	MPASS(rv);
317 
318 	memcpy(nlmsg_reserve_data_raw(&nw, sizeof(nl->nl_hdr)), &nl->nl_hdr,
319 	    sizeof(nl->nl_hdr));
320 
321 	rv = nlattr_add(&nw, RPCNL_REQUEST_BODY, xlen, mem);
322 	MPASS(rv);
323 
324 	rv = nlmsg_end(&nw);
325 	MPASS(rv);
326 
327 	mtx_lock(&nl->nl_lock);
328 	TAILQ_INSERT_TAIL(&nl->nl_pending, cr, cr_link);
329 	mtx_unlock(&nl->nl_lock);
330 
331 	nlmsg_flush(&nw);
332 
333 	mtx_lock(&nl->nl_lock);
334 	if (__predict_true(cr->cr_error == ETIMEDOUT))
335 		(void)msleep(cr, &nl->nl_lock, nl->nl_prio, nl->nl_wchan,
336 		    (nl->nl_timo ? nl->nl_timo : tvtohz(&utimeout)) /
337 		    nl->nl_retries);
338 	TAILQ_REMOVE(&nl->nl_pending, cr, cr_link);
339 	mtx_unlock(&nl->nl_lock);
340 
341 	if (__predict_true(cr->cr_error == 0)) {
342 		struct rpc_msg reply_msg = {
343 			.acpted_rply.ar_verf.oa_base = cr->cr_verf,
344 			.acpted_rply.ar_results.proc = (xdrproc_t)xdr_void,
345 		};
346 
347 		MPASS(cr->cr_mrep);
348 		if (ext && ext->rc_feedback)
349 			ext->rc_feedback(FEEDBACK_OK, proc,
350 			    ext->rc_feedback_arg);
351 		xdrmbuf_create(&xdrs, cr->cr_mrep, XDR_DECODE);
352 		rv = xdr_replymsg(&xdrs, &reply_msg);
353 		if (__predict_false(!rv)) {
354 			stat = errp->re_status = RPC_CANTDECODERES;
355 			goto out;
356 		}
357 		if ((reply_msg.rm_reply.rp_stat == MSG_ACCEPTED) &&
358 		    (reply_msg.acpted_rply.ar_stat == SUCCESS)) {
359 			struct mbuf *results;
360 
361                         stat = errp->re_status = RPC_SUCCESS;
362 			results = xdrmbuf_getall(&xdrs);
363 			if (__predict_true(AUTH_VALIDATE(auth, cr->cr_xid,
364 			    &reply_msg.acpted_rply.ar_verf, &results))) {
365                                 MPASS(results);
366                                 *resultsp = results;
367 				/* end successful completion */
368 			} else {
369 				stat = errp->re_status = RPC_AUTHERROR;
370 				errp->re_why = AUTH_INVALIDRESP;
371 			}
372 		} else {
373 			stat = _seterr_reply(&reply_msg, errp);
374 		}
375 		xdr_destroy(&xdrs);	/* frees cr->cr_mrep */
376 	} else {
377 		MPASS(cr->cr_mrep == NULL);
378 		errp->re_errno = cr->cr_error;
379 		stat = errp->re_status = RPC_CANTRECV;
380 		if (cr->cr_error == ETIMEDOUT && ++retries < nl->nl_retries) {
381 			cr->cr_xid = atomic_fetchadd_32(&nl->nl_xid, 1);
382 			goto retry;
383 		}
384 	}
385 out:
386 	free(cr, M_RPC);
387 	free(mem, M_RPC);
388 
389 	return (stat);
390 }
391 
392 static int
clnt_nl_reply(struct nlmsghdr * hdr,struct nl_pstate * npt)393 clnt_nl_reply(struct nlmsghdr *hdr, struct nl_pstate *npt)
394 {
395 	struct nl_reply_parsed attrs = {};
396 	struct nl_data *nl;
397 	struct ct_request *cr;
398 	struct mchain mc;
399 	int error;
400 
401 	CURVNET_ASSERT_SET();
402 
403 	if ((error = nl_parse_nlmsg(hdr, &rpcnl_parser, npt, &attrs)) != 0)
404 		return (error);
405 	if (attrs.data == NULL)
406 		return (EINVAL);
407 
408 	error = mc_get(&mc, NLA_DATA_LEN(attrs.data), M_WAITOK, MT_DATA, 0);
409 	MPASS(error == 0);
410 	m_copyback(mc_first(&mc), 0, NLA_DATA_LEN(attrs.data),
411 	    NLA_DATA(attrs.data));
412 
413 	rw_rlock(&rpcnl_global_lock);
414 	if ((nl = RB_FIND(nl_data_t, &rpcnl_clients,
415 	    &(struct nl_data){ .nl_hdr.group = attrs.group })) == NULL) {
416 		rw_runlock(&rpcnl_global_lock);
417 		mc_freem(&mc);
418 		return (EPROGUNAVAIL);
419 	};
420 	mtx_lock(&nl->nl_lock);
421 	rw_runlock(&rpcnl_global_lock);
422 
423 	TAILQ_FOREACH(cr, &nl->nl_pending, cr_link)
424 		if (cr->cr_xid == hdr->nlmsg_seq
425 #ifdef VIMAGE
426 		    && cr->cr_vnet == curvnet
427 #endif
428 		    )
429 			break;
430 	if (cr == NULL) {
431 		mtx_unlock(&nl->nl_lock);
432 		mc_freem(&mc);
433 		return (EPROCUNAVAIL);
434 	}
435 	cr->cr_mrep = mc_first(&mc);
436 	cr->cr_error = 0;
437 	wakeup(cr);
438 	mtx_unlock(&nl->nl_lock);
439 
440 	return (0);
441 }
442 
443 static void
clnt_nl_close(CLIENT * cl)444 clnt_nl_close(CLIENT *cl)
445 {
446 	struct nl_data *nl =  cl->cl_private;
447 	struct ct_request *cr;
448 
449 	mtx_lock(&nl->nl_lock);
450 	TAILQ_FOREACH(cr, &nl->nl_pending, cr_link) {
451 		cr->cr_error = ESHUTDOWN;
452 		wakeup(cr);
453 	}
454 	mtx_unlock(&nl->nl_lock);
455 }
456 
457 static void
clnt_nl_destroy(CLIENT * cl)458 clnt_nl_destroy(CLIENT *cl)
459 {
460 	struct nl_data *nl = cl->cl_private;
461 
462 	MPASS(TAILQ_EMPTY(&nl->nl_pending));
463 
464 	genl_unregister_group(rpcnl_family_id, nl->nl_hdr.group);
465 	rw_wlock(&rpcnl_global_lock);
466 	RB_REMOVE(nl_data_t, &rpcnl_clients, nl);
467 	rw_wlock(&rpcnl_global_lock);
468 
469 	mtx_destroy(&nl->nl_lock);
470 	free(nl, M_RPC);
471 	free(cl, M_RPC);
472 }
473 
474 static bool_t
clnt_nl_control(CLIENT * cl,u_int request,void * info)475 clnt_nl_control(CLIENT *cl, u_int request, void *info)
476 {
477 	struct nl_data *nl = (struct nl_data *)cl->cl_private;
478 
479 	mtx_lock(&nl->nl_lock);
480 	switch (request) {
481 	case CLSET_TIMEOUT:
482 		nl->nl_timo = tvtohz((struct timeval *)info);
483 		break;
484 
485 	case CLGET_TIMEOUT:
486 		*(struct timeval *)info =
487 		    (struct timeval){.tv_sec = nl->nl_timo / hz};
488 		break;
489 
490 	case CLSET_RETRIES:
491 		nl->nl_retries = *(u_int *)info;
492 		break;
493 
494 	case CLSET_WAITCHAN:
495 		nl->nl_wchan = (const char *)info;
496 		break;
497 
498 	case CLGET_WAITCHAN:
499 		*(const char **)info = nl->nl_wchan;
500 		break;
501 
502 	case CLSET_INTERRUPTIBLE:
503 		if (*(int *)info)
504 			nl->nl_prio |= PCATCH;
505 		else
506 			nl->nl_prio &= ~PCATCH;
507 		break;
508 
509 	case CLGET_INTERRUPTIBLE:
510 		*(int *)info = (nl->nl_prio & PCATCH) ? TRUE : FALSE;
511 		break;
512 
513 	default:
514 		mtx_unlock(&nl->nl_lock);
515 		printf("%s: unsupported request %u\n", __func__, request);
516 		return (FALSE);
517 	}
518 
519 	mtx_unlock(&nl->nl_lock);
520 	return (TRUE);
521 }
522