xref: /freebsd/sbin/hastd/proto.c (revision f7c32ed617858bcd22f8d1b03199099d50125721)
1 /*-
2  * SPDX-License-Identifier: BSD-2-Clause-FreeBSD
3  *
4  * Copyright (c) 2009-2010 The FreeBSD Foundation
5  * All rights reserved.
6  *
7  * This software was developed by Pawel Jakub Dawidek under sponsorship from
8  * the FreeBSD Foundation.
9  *
10  * Redistribution and use in source and binary forms, with or without
11  * modification, are permitted provided that the following conditions
12  * are met:
13  * 1. Redistributions of source code must retain the above copyright
14  *    notice, this list of conditions and the following disclaimer.
15  * 2. Redistributions in binary form must reproduce the above copyright
16  *    notice, this list of conditions and the following disclaimer in the
17  *    documentation and/or other materials provided with the distribution.
18  *
19  * THIS SOFTWARE IS PROVIDED BY THE AUTHORS AND CONTRIBUTORS ``AS IS'' AND
20  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
22  * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE LIABLE
23  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
25  * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
26  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
27  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
28  * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
29  * SUCH DAMAGE.
30  */
31 
32 #include <sys/cdefs.h>
33 __FBSDID("$FreeBSD$");
34 
#include <sys/types.h>
#include <sys/queue.h>
#include <sys/socket.h>

#include <errno.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <strings.h>
#include <unistd.h>

#include "pjdlog.h"
#include "proto.h"
#include "proto_impl.h"
47 
/* Roles a connection handle can play; stored in proto_conn.pc_side. */
#define	PROTO_SIDE_CLIENT		0
#define	PROTO_SIDE_SERVER_LISTEN	1
#define	PROTO_SIDE_SERVER_WORK		2

/*
 * Value kept in pc_magic while a handle is live.  The assertions in this
 * file check it, and proto_free() clears it before releasing the memory.
 */
#define	PROTO_CONN_MAGIC	0x907041c

/*
 * Connection handle given to callers.  It binds one registered protocol
 * to that protocol's private per-connection context.
 */
struct proto_conn {
	int		 pc_magic;	/* PROTO_CONN_MAGIC while valid */
	struct proto	*pc_proto;	/* protocol implementation */
	void		*pc_ctx;	/* protocol-private context */
	int		 pc_side;	/* one of PROTO_SIDE_* */
};

/*
 * Registered protocols.  Non-default protocols are inserted at the head.
 * The single default protocol is appended at the tail, so list walks
 * reach it last.
 */
static TAILQ_HEAD(, proto) protos = TAILQ_HEAD_INITIALIZER(protos);
60 
61 void
62 proto_register(struct proto *proto, bool isdefault)
63 {
64 	static bool seen_default = false;
65 
66 	if (!isdefault)
67 		TAILQ_INSERT_HEAD(&protos, proto, prt_next);
68 	else {
69 		PJDLOG_ASSERT(!seen_default);
70 		seen_default = true;
71 		TAILQ_INSERT_TAIL(&protos, proto, prt_next);
72 	}
73 }
74 
75 static struct proto_conn *
76 proto_alloc(struct proto *proto, int side)
77 {
78 	struct proto_conn *conn;
79 
80 	PJDLOG_ASSERT(proto != NULL);
81 	PJDLOG_ASSERT(side == PROTO_SIDE_CLIENT ||
82 	    side == PROTO_SIDE_SERVER_LISTEN ||
83 	    side == PROTO_SIDE_SERVER_WORK);
84 
85 	conn = malloc(sizeof(*conn));
86 	if (conn != NULL) {
87 		conn->pc_proto = proto;
88 		conn->pc_side = side;
89 		conn->pc_magic = PROTO_CONN_MAGIC;
90 	}
91 	return (conn);
92 }
93 
94 static void
95 proto_free(struct proto_conn *conn)
96 {
97 
98 	PJDLOG_ASSERT(conn != NULL);
99 	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
100 	PJDLOG_ASSERT(conn->pc_side == PROTO_SIDE_CLIENT ||
101 	    conn->pc_side == PROTO_SIDE_SERVER_LISTEN ||
102 	    conn->pc_side == PROTO_SIDE_SERVER_WORK);
103 	PJDLOG_ASSERT(conn->pc_proto != NULL);
104 
105 	bzero(conn, sizeof(*conn));
106 	free(conn);
107 }
108 
109 static int
110 proto_common_setup(const char *srcaddr, const char *dstaddr,
111     struct proto_conn **connp, int side)
112 {
113 	struct proto *proto;
114 	struct proto_conn *conn;
115 	void *ctx;
116 	int ret;
117 
118 	PJDLOG_ASSERT(side == PROTO_SIDE_CLIENT ||
119 	    side == PROTO_SIDE_SERVER_LISTEN);
120 
121 	TAILQ_FOREACH(proto, &protos, prt_next) {
122 		if (side == PROTO_SIDE_CLIENT) {
123 			if (proto->prt_client == NULL)
124 				ret = -1;
125 			else
126 				ret = proto->prt_client(srcaddr, dstaddr, &ctx);
127 		} else /* if (side == PROTO_SIDE_SERVER_LISTEN) */ {
128 			if (proto->prt_server == NULL)
129 				ret = -1;
130 			else
131 				ret = proto->prt_server(dstaddr, &ctx);
132 		}
133 		/*
134 		 * ret == 0  - success
135 		 * ret == -1 - dstaddr is not for this protocol
136 		 * ret > 0   - right protocol, but an error occurred
137 		 */
138 		if (ret >= 0)
139 			break;
140 	}
141 	if (proto == NULL) {
142 		/* Unrecognized address. */
143 		errno = EINVAL;
144 		return (-1);
145 	}
146 	if (ret > 0) {
147 		/* An error occurred. */
148 		errno = ret;
149 		return (-1);
150 	}
151 	conn = proto_alloc(proto, side);
152 	if (conn == NULL) {
153 		if (proto->prt_close != NULL)
154 			proto->prt_close(ctx);
155 		errno = ENOMEM;
156 		return (-1);
157 	}
158 	conn->pc_ctx = ctx;
159 	*connp = conn;
160 
161 	return (0);
162 }
163 
164 int
165 proto_client(const char *srcaddr, const char *dstaddr,
166     struct proto_conn **connp)
167 {
168 
169 	return (proto_common_setup(srcaddr, dstaddr, connp, PROTO_SIDE_CLIENT));
170 }
171 
172 int
173 proto_connect(struct proto_conn *conn, int timeout)
174 {
175 	int ret;
176 
177 	PJDLOG_ASSERT(conn != NULL);
178 	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
179 	PJDLOG_ASSERT(conn->pc_side == PROTO_SIDE_CLIENT);
180 	PJDLOG_ASSERT(conn->pc_proto != NULL);
181 	PJDLOG_ASSERT(conn->pc_proto->prt_connect != NULL);
182 	PJDLOG_ASSERT(timeout >= -1);
183 
184 	ret = conn->pc_proto->prt_connect(conn->pc_ctx, timeout);
185 	if (ret != 0) {
186 		errno = ret;
187 		return (-1);
188 	}
189 
190 	return (0);
191 }
192 
193 int
194 proto_connect_wait(struct proto_conn *conn, int timeout)
195 {
196 	int ret;
197 
198 	PJDLOG_ASSERT(conn != NULL);
199 	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
200 	PJDLOG_ASSERT(conn->pc_side == PROTO_SIDE_CLIENT);
201 	PJDLOG_ASSERT(conn->pc_proto != NULL);
202 	PJDLOG_ASSERT(conn->pc_proto->prt_connect_wait != NULL);
203 	PJDLOG_ASSERT(timeout >= 0);
204 
205 	ret = conn->pc_proto->prt_connect_wait(conn->pc_ctx, timeout);
206 	if (ret != 0) {
207 		errno = ret;
208 		return (-1);
209 	}
210 
211 	return (0);
212 }
213 
214 int
215 proto_server(const char *addr, struct proto_conn **connp)
216 {
217 
218 	return (proto_common_setup(NULL, addr, connp, PROTO_SIDE_SERVER_LISTEN));
219 }
220 
221 int
222 proto_accept(struct proto_conn *conn, struct proto_conn **newconnp)
223 {
224 	struct proto_conn *newconn;
225 	int ret;
226 
227 	PJDLOG_ASSERT(conn != NULL);
228 	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
229 	PJDLOG_ASSERT(conn->pc_side == PROTO_SIDE_SERVER_LISTEN);
230 	PJDLOG_ASSERT(conn->pc_proto != NULL);
231 	PJDLOG_ASSERT(conn->pc_proto->prt_accept != NULL);
232 
233 	newconn = proto_alloc(conn->pc_proto, PROTO_SIDE_SERVER_WORK);
234 	if (newconn == NULL)
235 		return (-1);
236 
237 	ret = conn->pc_proto->prt_accept(conn->pc_ctx, &newconn->pc_ctx);
238 	if (ret != 0) {
239 		proto_free(newconn);
240 		errno = ret;
241 		return (-1);
242 	}
243 
244 	*newconnp = newconn;
245 
246 	return (0);
247 }
248 
249 int
250 proto_send(const struct proto_conn *conn, const void *data, size_t size)
251 {
252 	int ret;
253 
254 	PJDLOG_ASSERT(conn != NULL);
255 	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
256 	PJDLOG_ASSERT(conn->pc_proto != NULL);
257 	PJDLOG_ASSERT(conn->pc_proto->prt_send != NULL);
258 
259 	ret = conn->pc_proto->prt_send(conn->pc_ctx, data, size, -1);
260 	if (ret != 0) {
261 		errno = ret;
262 		return (-1);
263 	}
264 	return (0);
265 }
266 
267 int
268 proto_recv(const struct proto_conn *conn, void *data, size_t size)
269 {
270 	int ret;
271 
272 	PJDLOG_ASSERT(conn != NULL);
273 	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
274 	PJDLOG_ASSERT(conn->pc_proto != NULL);
275 	PJDLOG_ASSERT(conn->pc_proto->prt_recv != NULL);
276 
277 	ret = conn->pc_proto->prt_recv(conn->pc_ctx, data, size, NULL);
278 	if (ret != 0) {
279 		errno = ret;
280 		return (-1);
281 	}
282 	return (0);
283 }
284 
285 int
286 proto_connection_send(const struct proto_conn *conn, struct proto_conn *mconn)
287 {
288 	const char *protoname;
289 	int ret, fd;
290 
291 	PJDLOG_ASSERT(conn != NULL);
292 	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
293 	PJDLOG_ASSERT(conn->pc_proto != NULL);
294 	PJDLOG_ASSERT(conn->pc_proto->prt_send != NULL);
295 	PJDLOG_ASSERT(mconn != NULL);
296 	PJDLOG_ASSERT(mconn->pc_magic == PROTO_CONN_MAGIC);
297 	PJDLOG_ASSERT(mconn->pc_proto != NULL);
298 	fd = proto_descriptor(mconn);
299 	PJDLOG_ASSERT(fd >= 0);
300 	protoname = mconn->pc_proto->prt_name;
301 	PJDLOG_ASSERT(protoname != NULL);
302 
303 	ret = conn->pc_proto->prt_send(conn->pc_ctx,
304 	    (const unsigned char *)protoname, strlen(protoname) + 1, fd);
305 	proto_close(mconn);
306 	if (ret != 0) {
307 		errno = ret;
308 		return (-1);
309 	}
310 	return (0);
311 }
312 
313 int
314 proto_connection_recv(const struct proto_conn *conn, bool client,
315     struct proto_conn **newconnp)
316 {
317 	char protoname[128];
318 	struct proto *proto;
319 	struct proto_conn *newconn;
320 	int ret, fd;
321 
322 	PJDLOG_ASSERT(conn != NULL);
323 	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
324 	PJDLOG_ASSERT(conn->pc_proto != NULL);
325 	PJDLOG_ASSERT(conn->pc_proto->prt_recv != NULL);
326 	PJDLOG_ASSERT(newconnp != NULL);
327 
328 	bzero(protoname, sizeof(protoname));
329 
330 	ret = conn->pc_proto->prt_recv(conn->pc_ctx, (unsigned char *)protoname,
331 	    sizeof(protoname) - 1, &fd);
332 	if (ret != 0) {
333 		errno = ret;
334 		return (-1);
335 	}
336 
337 	PJDLOG_ASSERT(fd >= 0);
338 
339 	TAILQ_FOREACH(proto, &protos, prt_next) {
340 		if (strcmp(proto->prt_name, protoname) == 0)
341 			break;
342 	}
343 	if (proto == NULL) {
344 		errno = EINVAL;
345 		return (-1);
346 	}
347 
348 	newconn = proto_alloc(proto,
349 	    client ? PROTO_SIDE_CLIENT : PROTO_SIDE_SERVER_WORK);
350 	if (newconn == NULL)
351 		return (-1);
352 	PJDLOG_ASSERT(newconn->pc_proto->prt_wrap != NULL);
353 	ret = newconn->pc_proto->prt_wrap(fd, client, &newconn->pc_ctx);
354 	if (ret != 0) {
355 		proto_free(newconn);
356 		errno = ret;
357 		return (-1);
358 	}
359 
360 	*newconnp = newconn;
361 
362 	return (0);
363 }
364 
365 int
366 proto_descriptor(const struct proto_conn *conn)
367 {
368 
369 	PJDLOG_ASSERT(conn != NULL);
370 	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
371 	PJDLOG_ASSERT(conn->pc_proto != NULL);
372 	PJDLOG_ASSERT(conn->pc_proto->prt_descriptor != NULL);
373 
374 	return (conn->pc_proto->prt_descriptor(conn->pc_ctx));
375 }
376 
377 bool
378 proto_address_match(const struct proto_conn *conn, const char *addr)
379 {
380 
381 	PJDLOG_ASSERT(conn != NULL);
382 	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
383 	PJDLOG_ASSERT(conn->pc_proto != NULL);
384 	PJDLOG_ASSERT(conn->pc_proto->prt_address_match != NULL);
385 
386 	return (conn->pc_proto->prt_address_match(conn->pc_ctx, addr));
387 }
388 
389 void
390 proto_local_address(const struct proto_conn *conn, char *addr, size_t size)
391 {
392 
393 	PJDLOG_ASSERT(conn != NULL);
394 	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
395 	PJDLOG_ASSERT(conn->pc_proto != NULL);
396 	PJDLOG_ASSERT(conn->pc_proto->prt_local_address != NULL);
397 
398 	conn->pc_proto->prt_local_address(conn->pc_ctx, addr, size);
399 }
400 
401 void
402 proto_remote_address(const struct proto_conn *conn, char *addr, size_t size)
403 {
404 
405 	PJDLOG_ASSERT(conn != NULL);
406 	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
407 	PJDLOG_ASSERT(conn->pc_proto != NULL);
408 	PJDLOG_ASSERT(conn->pc_proto->prt_remote_address != NULL);
409 
410 	conn->pc_proto->prt_remote_address(conn->pc_ctx, addr, size);
411 }
412 
413 int
414 proto_timeout(const struct proto_conn *conn, int timeout)
415 {
416 	struct timeval tv;
417 	int fd;
418 
419 	PJDLOG_ASSERT(conn != NULL);
420 	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
421 	PJDLOG_ASSERT(conn->pc_proto != NULL);
422 
423 	fd = proto_descriptor(conn);
424 	if (fd == -1)
425 		return (-1);
426 
427 	tv.tv_sec = timeout;
428 	tv.tv_usec = 0;
429 	if (setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)) == -1)
430 		return (-1);
431 	if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)) == -1)
432 		return (-1);
433 
434 	return (0);
435 }
436 
437 void
438 proto_close(struct proto_conn *conn)
439 {
440 
441 	PJDLOG_ASSERT(conn != NULL);
442 	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
443 	PJDLOG_ASSERT(conn->pc_proto != NULL);
444 	PJDLOG_ASSERT(conn->pc_proto->prt_close != NULL);
445 
446 	conn->pc_proto->prt_close(conn->pc_ctx);
447 	proto_free(conn);
448 }
449