xref: /freebsd/contrib/openbsm/bin/auditdistd/proto.c (revision 5d3e7166f6a0187fa3f8831b16a06bd9955c21ff)
1 /*-
2  * Copyright (c) 2009-2010 The FreeBSD Foundation
3  * All rights reserved.
4  *
5  * This software was developed by Pawel Jakub Dawidek under sponsorship from
6  * the FreeBSD Foundation.
7  *
8  * Redistribution and use in source and binary forms, with or without
9  * modification, are permitted provided that the following conditions
10  * are met:
11  * 1. Redistributions of source code must retain the above copyright
12  *    notice, this list of conditions and the following disclaimer.
13  * 2. Redistributions in binary form must reproduce the above copyright
14  *    notice, this list of conditions and the following disclaimer in the
15  *    documentation and/or other materials provided with the distribution.
16  *
17  * THIS SOFTWARE IS PROVIDED BY THE AUTHORS AND CONTRIBUTORS ``AS IS'' AND
18  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20  * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE LIABLE
21  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
22  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
23  * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
24  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
25  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
26  * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
27  * SUCH DAMAGE.
28  */
29 
30 #include <sys/types.h>
31 #include <sys/queue.h>
32 #include <sys/socket.h>
33 
34 #include <errno.h>
35 #include <stdint.h>
36 #include <string.h>
37 #include <strings.h>
38 
39 #include "pjdlog.h"
40 #include "proto.h"
41 #include "proto_impl.h"
42 
/*
 * Values stored in proto_conn.pc_side.
 */
#define	PROTO_SIDE_CLIENT		0
#define	PROTO_SIDE_SERVER_LISTEN	1
#define	PROTO_SIDE_SERVER_WORK		2

/*
 * Stamp kept in proto_conn.pc_magic; the functions in this file assert on it
 * before using a handle.
 */
#define	PROTO_CONN_MAGIC	0x907041c

/*
 * Connection handle: a registered protocol paired with the context that
 * protocol produced for this particular connection.
 */
struct proto_conn {
	int		 pc_magic;	/* PROTO_CONN_MAGIC */
	struct proto	*pc_proto;	/* protocol whose prt_* hooks are called */
	void		*pc_ctx;	/* opaque context passed to those hooks */
	int		 pc_side;	/* one of PROTO_SIDE_* */
};

/*
 * Every protocol added through proto_register().
 */
static TAILQ_HEAD(, proto) protos = TAILQ_HEAD_INITIALIZER(protos);
55 
56 void
57 proto_register(struct proto *proto, bool isdefault)
58 {
59 	static bool seen_default = false;
60 
61 	if (!isdefault)
62 		TAILQ_INSERT_HEAD(&protos, proto, prt_next);
63 	else {
64 		PJDLOG_ASSERT(!seen_default);
65 		seen_default = true;
66 		TAILQ_INSERT_TAIL(&protos, proto, prt_next);
67 	}
68 }
69 
70 static struct proto_conn *
71 proto_alloc(struct proto *proto, int side)
72 {
73 	struct proto_conn *conn;
74 
75 	PJDLOG_ASSERT(proto != NULL);
76 	PJDLOG_ASSERT(side == PROTO_SIDE_CLIENT ||
77 	    side == PROTO_SIDE_SERVER_LISTEN ||
78 	    side == PROTO_SIDE_SERVER_WORK);
79 
80 	conn = malloc(sizeof(*conn));
81 	if (conn != NULL) {
82 		conn->pc_proto = proto;
83 		conn->pc_side = side;
84 		conn->pc_magic = PROTO_CONN_MAGIC;
85 	}
86 	return (conn);
87 }
88 
89 static void
90 proto_free(struct proto_conn *conn)
91 {
92 
93 	PJDLOG_ASSERT(conn != NULL);
94 	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
95 	PJDLOG_ASSERT(conn->pc_side == PROTO_SIDE_CLIENT ||
96 	    conn->pc_side == PROTO_SIDE_SERVER_LISTEN ||
97 	    conn->pc_side == PROTO_SIDE_SERVER_WORK);
98 	PJDLOG_ASSERT(conn->pc_proto != NULL);
99 
100 	bzero(conn, sizeof(*conn));
101 	free(conn);
102 }
103 
104 static int
105 proto_common_setup(const char *srcaddr, const char *dstaddr, int timeout,
106     int side, struct proto_conn **connp)
107 {
108 	struct proto *proto;
109 	struct proto_conn *conn;
110 	void *ctx;
111 	int ret;
112 
113 	PJDLOG_ASSERT(side == PROTO_SIDE_CLIENT ||
114 	    side == PROTO_SIDE_SERVER_LISTEN);
115 
116 	TAILQ_FOREACH(proto, &protos, prt_next) {
117 		if (side == PROTO_SIDE_CLIENT) {
118 			if (proto->prt_connect == NULL) {
119 				ret = -1;
120 			} else {
121 				ret = proto->prt_connect(srcaddr, dstaddr,
122 				    timeout, &ctx);
123 			}
124 		} else /* if (side == PROTO_SIDE_SERVER_LISTEN) */ {
125 			if (proto->prt_server == NULL)
126 				ret = -1;
127 			else
128 				ret = proto->prt_server(dstaddr, &ctx);
129 		}
130 		/*
131 		 * ret == 0  - success
132 		 * ret == -1 - dstaddr is not for this protocol
133 		 * ret > 0   - right protocol, but an error occured
134 		 */
135 		if (ret >= 0)
136 			break;
137 	}
138 	if (proto == NULL) {
139 		/* Unrecognized address. */
140 		errno = EINVAL;
141 		return (-1);
142 	}
143 	if (ret > 0) {
144 		/* An error occured. */
145 		errno = ret;
146 		return (-1);
147 	}
148 	conn = proto_alloc(proto, side);
149 	if (conn == NULL) {
150 		if (proto->prt_close != NULL)
151 			proto->prt_close(ctx);
152 		errno = ENOMEM;
153 		return (-1);
154 	}
155 	conn->pc_ctx = ctx;
156 	*connp = conn;
157 
158 	return (0);
159 }
160 
161 int
162 proto_connect(const char *srcaddr, const char *dstaddr, int timeout,
163     struct proto_conn **connp)
164 {
165 
166 	PJDLOG_ASSERT(srcaddr == NULL || srcaddr[0] != '\0');
167 	PJDLOG_ASSERT(dstaddr != NULL);
168 	PJDLOG_ASSERT(timeout >= -1);
169 
170 	return (proto_common_setup(srcaddr, dstaddr, timeout,
171 	    PROTO_SIDE_CLIENT, connp));
172 }
173 
174 int
175 proto_connect_wait(struct proto_conn *conn, int timeout)
176 {
177 	int error;
178 
179 	PJDLOG_ASSERT(conn != NULL);
180 	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
181 	PJDLOG_ASSERT(conn->pc_side == PROTO_SIDE_CLIENT);
182 	PJDLOG_ASSERT(conn->pc_proto != NULL);
183 	PJDLOG_ASSERT(conn->pc_proto->prt_connect_wait != NULL);
184 	PJDLOG_ASSERT(timeout >= 0);
185 
186 	error = conn->pc_proto->prt_connect_wait(conn->pc_ctx, timeout);
187 	if (error != 0) {
188 		errno = error;
189 		return (-1);
190 	}
191 
192 	return (0);
193 }
194 
195 int
196 proto_server(const char *addr, struct proto_conn **connp)
197 {
198 
199 	PJDLOG_ASSERT(addr != NULL);
200 
201 	return (proto_common_setup(NULL, addr, -1, PROTO_SIDE_SERVER_LISTEN,
202 	    connp));
203 }
204 
205 int
206 proto_accept(struct proto_conn *conn, struct proto_conn **newconnp)
207 {
208 	struct proto_conn *newconn;
209 	int error;
210 
211 	PJDLOG_ASSERT(conn != NULL);
212 	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
213 	PJDLOG_ASSERT(conn->pc_side == PROTO_SIDE_SERVER_LISTEN);
214 	PJDLOG_ASSERT(conn->pc_proto != NULL);
215 	PJDLOG_ASSERT(conn->pc_proto->prt_accept != NULL);
216 
217 	newconn = proto_alloc(conn->pc_proto, PROTO_SIDE_SERVER_WORK);
218 	if (newconn == NULL)
219 		return (-1);
220 
221 	error = conn->pc_proto->prt_accept(conn->pc_ctx, &newconn->pc_ctx);
222 	if (error != 0) {
223 		proto_free(newconn);
224 		errno = error;
225 		return (-1);
226 	}
227 
228 	*newconnp = newconn;
229 
230 	return (0);
231 }
232 
233 int
234 proto_send(const struct proto_conn *conn, const void *data, size_t size)
235 {
236 	int error;
237 
238 	PJDLOG_ASSERT(conn != NULL);
239 	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
240 	PJDLOG_ASSERT(conn->pc_proto != NULL);
241 	PJDLOG_ASSERT(conn->pc_proto->prt_send != NULL);
242 
243 	error = conn->pc_proto->prt_send(conn->pc_ctx, data, size, -1);
244 	if (error != 0) {
245 		errno = error;
246 		return (-1);
247 	}
248 	return (0);
249 }
250 
251 int
252 proto_recv(const struct proto_conn *conn, void *data, size_t size)
253 {
254 	int error;
255 
256 	PJDLOG_ASSERT(conn != NULL);
257 	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
258 	PJDLOG_ASSERT(conn->pc_proto != NULL);
259 	PJDLOG_ASSERT(conn->pc_proto->prt_recv != NULL);
260 
261 	error = conn->pc_proto->prt_recv(conn->pc_ctx, data, size, NULL);
262 	if (error != 0) {
263 		errno = error;
264 		return (-1);
265 	}
266 	return (0);
267 }
268 
269 int
270 proto_connection_send(const struct proto_conn *conn, struct proto_conn *mconn)
271 {
272 	const char *protoname;
273 	int error, fd;
274 
275 	PJDLOG_ASSERT(conn != NULL);
276 	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
277 	PJDLOG_ASSERT(conn->pc_proto != NULL);
278 	PJDLOG_ASSERT(conn->pc_proto->prt_send != NULL);
279 	PJDLOG_ASSERT(mconn != NULL);
280 	PJDLOG_ASSERT(mconn->pc_magic == PROTO_CONN_MAGIC);
281 	PJDLOG_ASSERT(mconn->pc_proto != NULL);
282 	fd = proto_descriptor(mconn);
283 	PJDLOG_ASSERT(fd >= 0);
284 	protoname = mconn->pc_proto->prt_name;
285 	PJDLOG_ASSERT(protoname != NULL);
286 
287 	error = conn->pc_proto->prt_send(conn->pc_ctx,
288 	    (const unsigned char *)protoname, strlen(protoname) + 1, fd);
289 	proto_close(mconn);
290 	if (error != 0) {
291 		errno = error;
292 		return (-1);
293 	}
294 	return (0);
295 }
296 
297 int
298 proto_wrap(const char *protoname, bool client, int fd,
299     struct proto_conn **newconnp)
300 {
301 	struct proto *proto;
302 	struct proto_conn *newconn;
303 	int error;
304 
305 	TAILQ_FOREACH(proto, &protos, prt_next) {
306 		if (strcmp(proto->prt_name, protoname) == 0)
307 			break;
308 	}
309 	if (proto == NULL) {
310 		errno = EINVAL;
311 		return (-1);
312 	}
313 
314 	newconn = proto_alloc(proto,
315 	    client ? PROTO_SIDE_CLIENT : PROTO_SIDE_SERVER_WORK);
316 	if (newconn == NULL)
317 		return (-1);
318 	PJDLOG_ASSERT(newconn->pc_proto->prt_wrap != NULL);
319 	error = newconn->pc_proto->prt_wrap(fd, client, &newconn->pc_ctx);
320 	if (error != 0) {
321 		proto_free(newconn);
322 		errno = error;
323 		return (-1);
324 	}
325 
326 	*newconnp = newconn;
327 
328 	return (0);
329 }
330 
331 int
332 proto_connection_recv(const struct proto_conn *conn, bool client,
333     struct proto_conn **newconnp)
334 {
335 	char protoname[128];
336 	int error, fd;
337 
338 	PJDLOG_ASSERT(conn != NULL);
339 	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
340 	PJDLOG_ASSERT(conn->pc_proto != NULL);
341 	PJDLOG_ASSERT(conn->pc_proto->prt_recv != NULL);
342 	PJDLOG_ASSERT(newconnp != NULL);
343 
344 	bzero(protoname, sizeof(protoname));
345 
346 	error = conn->pc_proto->prt_recv(conn->pc_ctx,
347 	    (unsigned char *)protoname, sizeof(protoname) - 1, &fd);
348 	if (error != 0) {
349 		errno = error;
350 		return (-1);
351 	}
352 
353 	PJDLOG_ASSERT(fd >= 0);
354 
355 	return (proto_wrap(protoname, client, fd, newconnp));
356 }
357 
358 int
359 proto_descriptor(const struct proto_conn *conn)
360 {
361 
362 	PJDLOG_ASSERT(conn != NULL);
363 	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
364 	PJDLOG_ASSERT(conn->pc_proto != NULL);
365 	PJDLOG_ASSERT(conn->pc_proto->prt_descriptor != NULL);
366 
367 	return (conn->pc_proto->prt_descriptor(conn->pc_ctx));
368 }
369 
370 bool
371 proto_address_match(const struct proto_conn *conn, const char *addr)
372 {
373 
374 	PJDLOG_ASSERT(conn != NULL);
375 	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
376 	PJDLOG_ASSERT(conn->pc_proto != NULL);
377 	PJDLOG_ASSERT(conn->pc_proto->prt_address_match != NULL);
378 
379 	return (conn->pc_proto->prt_address_match(conn->pc_ctx, addr));
380 }
381 
382 void
383 proto_local_address(const struct proto_conn *conn, char *addr, size_t size)
384 {
385 
386 	PJDLOG_ASSERT(conn != NULL);
387 	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
388 	PJDLOG_ASSERT(conn->pc_proto != NULL);
389 	PJDLOG_ASSERT(conn->pc_proto->prt_local_address != NULL);
390 
391 	conn->pc_proto->prt_local_address(conn->pc_ctx, addr, size);
392 }
393 
394 void
395 proto_remote_address(const struct proto_conn *conn, char *addr, size_t size)
396 {
397 
398 	PJDLOG_ASSERT(conn != NULL);
399 	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
400 	PJDLOG_ASSERT(conn->pc_proto != NULL);
401 	PJDLOG_ASSERT(conn->pc_proto->prt_remote_address != NULL);
402 
403 	conn->pc_proto->prt_remote_address(conn->pc_ctx, addr, size);
404 }
405 
406 int
407 proto_timeout(const struct proto_conn *conn, int timeout)
408 {
409 	struct timeval tv;
410 	int fd;
411 
412 	PJDLOG_ASSERT(conn != NULL);
413 	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
414 	PJDLOG_ASSERT(conn->pc_proto != NULL);
415 
416 	fd = proto_descriptor(conn);
417 	if (fd < 0)
418 		return (-1);
419 
420 	tv.tv_sec = timeout;
421 	tv.tv_usec = 0;
422 	if (setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)) < 0)
423 		return (-1);
424 	if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)) < 0)
425 		return (-1);
426 
427 	return (0);
428 }
429 
430 void
431 proto_close(struct proto_conn *conn)
432 {
433 
434 	PJDLOG_ASSERT(conn != NULL);
435 	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
436 	PJDLOG_ASSERT(conn->pc_proto != NULL);
437 	PJDLOG_ASSERT(conn->pc_proto->prt_close != NULL);
438 
439 	conn->pc_proto->prt_close(conn->pc_ctx);
440 	proto_free(conn);
441 }
442 
443 int
444 proto_exec(int argc, char *argv[])
445 {
446 	struct proto *proto;
447 	int error;
448 
449 	if (argc == 0) {
450 		errno = EINVAL;
451 		return (-1);
452 	}
453 	TAILQ_FOREACH(proto, &protos, prt_next) {
454 		if (strcmp(proto->prt_name, argv[0]) == 0)
455 			break;
456 	}
457 	if (proto == NULL) {
458 		errno = EINVAL;
459 		return (-1);
460 	}
461 	if (proto->prt_exec == NULL) {
462 		errno = EOPNOTSUPP;
463 		return (-1);
464 	}
465 	error = proto->prt_exec(argc, argv);
466 	if (error != 0) {
467 		errno = error;
468 		return (-1);
469 	}
470 	/* NOTREACHED */
471 	return (0);
472 }
473 
/*
 * One name/value setting.  Both strings are heap copies owned by the entry.
 */
struct proto_nvpair {
	char	*pnv_name;
	char	*pnv_value;
	TAILQ_ENTRY(proto_nvpair) pnv_next;
};

/*
 * All settings stored with proto_set().  Names are unique within the list.
 */
static TAILQ_HEAD(, proto_nvpair) proto_nvpairs =
    TAILQ_HEAD_INITIALIZER(proto_nvpairs);

/*
 * Return the entry called name, or NULL if there is none.
 */
static struct proto_nvpair *
proto_nvpair_find(const char *name)
{
	struct proto_nvpair *pnv;

	TAILQ_FOREACH(pnv, &proto_nvpairs, pnv_next) {
		if (strcmp(pnv->pnv_name, name) == 0)
			return (pnv);
	}
	return (NULL);
}

/*
 * Store value under name, replacing any value stored under that name before.
 * Both strings are copied.
 *
 * Returns 0 on success and -1 when memory cannot be allocated; in the latter
 * case the store is left exactly as it was, so an existing setting survives
 * a failed update.
 */
int
proto_set(const char *name, const char *value)
{
	struct proto_nvpair *pnv;
	char *newvalue;

	/*
	 * Copy the new value before touching the store.  This keeps the old
	 * value intact if the copy fails, and it keeps the call safe when
	 * value points at the string currently stored (for example the
	 * result of proto_get() for the same name).
	 */
	newvalue = strdup(value);
	if (newvalue == NULL)
		return (-1);

	pnv = proto_nvpair_find(name);
	if (pnv != NULL) {
		free(pnv->pnv_value);
		pnv->pnv_value = newvalue;
		return (0);
	}

	pnv = malloc(sizeof(*pnv));
	if (pnv == NULL) {
		free(newvalue);
		return (-1);
	}
	pnv->pnv_name = strdup(name);
	if (pnv->pnv_name == NULL) {
		free(pnv);
		free(newvalue);
		return (-1);
	}
	pnv->pnv_value = newvalue;
	TAILQ_INSERT_TAIL(&proto_nvpairs, pnv, pnv_next);

	return (0);
}

/*
 * Return the value stored under name, or NULL if nothing is stored under it.
 *
 * The returned string belongs to the store; it stays valid until the next
 * proto_set() call for the same name.
 */
const char *
proto_get(const char *name)
{
	struct proto_nvpair *pnv;

	pnv = proto_nvpair_find(name);
	if (pnv == NULL)
		return (NULL);
	return (pnv->pnv_value);
}
528