xref: /freebsd/sbin/hastd/proto.c (revision 87b759f0fa1f7554d50ce640c40138512bbded44)
1 /*-
2  * SPDX-License-Identifier: BSD-2-Clause
3  *
4  * Copyright (c) 2009-2010 The FreeBSD Foundation
5  *
6  * This software was developed by Pawel Jakub Dawidek under sponsorship from
7  * the FreeBSD Foundation.
8  *
9  * Redistribution and use in source and binary forms, with or without
10  * modification, are permitted provided that the following conditions
11  * are met:
12  * 1. Redistributions of source code must retain the above copyright
13  *    notice, this list of conditions and the following disclaimer.
14  * 2. Redistributions in binary form must reproduce the above copyright
15  *    notice, this list of conditions and the following disclaimer in the
16  *    documentation and/or other materials provided with the distribution.
17  *
18  * THIS SOFTWARE IS PROVIDED BY THE AUTHORS AND CONTRIBUTORS ``AS IS'' AND
19  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
21  * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE LIABLE
22  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
24  * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
25  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
26  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
27  * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
28  * SUCH DAMAGE.
29  */
30 
#include <sys/types.h>
#include <sys/queue.h>
#include <sys/socket.h>

#include <errno.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <strings.h>
#include <unistd.h>

#include "pjdlog.h"
#include "proto.h"
#include "proto_impl.h"
43 
/* Stamped into every live proto_conn; checked by the assertions below. */
#define	PROTO_CONN_MAGIC	0x907041c

/* Values for proto_conn.pc_side. */
#define	PROTO_SIDE_CLIENT		0
#define	PROTO_SIDE_SERVER_LISTEN	1
#define	PROTO_SIDE_SERVER_WORK		2

/*
 * Generic connection handle: pairs a protocol implementation with the
 * protocol-private context it returned.
 */
struct proto_conn {
	int		 pc_magic;	/* PROTO_CONN_MAGIC while allocated */
	struct proto	*pc_proto;	/* owning protocol implementation */
	void		*pc_ctx;	/* protocol-private state */
	int		 pc_side;	/* one of PROTO_SIDE_* */
};

/*
 * Registered protocols.  proto_register() puts the default protocol at the
 * tail and every other protocol at the head.
 */
static TAILQ_HEAD(, proto) protos = TAILQ_HEAD_INITIALIZER(protos);
56 
57 void
58 proto_register(struct proto *proto, bool isdefault)
59 {
60 	static bool seen_default = false;
61 
62 	if (!isdefault)
63 		TAILQ_INSERT_HEAD(&protos, proto, prt_next);
64 	else {
65 		PJDLOG_ASSERT(!seen_default);
66 		seen_default = true;
67 		TAILQ_INSERT_TAIL(&protos, proto, prt_next);
68 	}
69 }
70 
71 static struct proto_conn *
72 proto_alloc(struct proto *proto, int side)
73 {
74 	struct proto_conn *conn;
75 
76 	PJDLOG_ASSERT(proto != NULL);
77 	PJDLOG_ASSERT(side == PROTO_SIDE_CLIENT ||
78 	    side == PROTO_SIDE_SERVER_LISTEN ||
79 	    side == PROTO_SIDE_SERVER_WORK);
80 
81 	conn = malloc(sizeof(*conn));
82 	if (conn != NULL) {
83 		conn->pc_proto = proto;
84 		conn->pc_side = side;
85 		conn->pc_magic = PROTO_CONN_MAGIC;
86 	}
87 	return (conn);
88 }
89 
90 static void
91 proto_free(struct proto_conn *conn)
92 {
93 
94 	PJDLOG_ASSERT(conn != NULL);
95 	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
96 	PJDLOG_ASSERT(conn->pc_side == PROTO_SIDE_CLIENT ||
97 	    conn->pc_side == PROTO_SIDE_SERVER_LISTEN ||
98 	    conn->pc_side == PROTO_SIDE_SERVER_WORK);
99 	PJDLOG_ASSERT(conn->pc_proto != NULL);
100 
101 	bzero(conn, sizeof(*conn));
102 	free(conn);
103 }
104 
105 static int
106 proto_common_setup(const char *srcaddr, const char *dstaddr,
107     struct proto_conn **connp, int side)
108 {
109 	struct proto *proto;
110 	struct proto_conn *conn;
111 	void *ctx;
112 	int ret;
113 
114 	PJDLOG_ASSERT(side == PROTO_SIDE_CLIENT ||
115 	    side == PROTO_SIDE_SERVER_LISTEN);
116 
117 	TAILQ_FOREACH(proto, &protos, prt_next) {
118 		if (side == PROTO_SIDE_CLIENT) {
119 			if (proto->prt_client == NULL)
120 				ret = -1;
121 			else
122 				ret = proto->prt_client(srcaddr, dstaddr, &ctx);
123 		} else /* if (side == PROTO_SIDE_SERVER_LISTEN) */ {
124 			if (proto->prt_server == NULL)
125 				ret = -1;
126 			else
127 				ret = proto->prt_server(dstaddr, &ctx);
128 		}
129 		/*
130 		 * ret == 0  - success
131 		 * ret == -1 - dstaddr is not for this protocol
132 		 * ret > 0   - right protocol, but an error occurred
133 		 */
134 		if (ret >= 0)
135 			break;
136 	}
137 	if (proto == NULL) {
138 		/* Unrecognized address. */
139 		errno = EINVAL;
140 		return (-1);
141 	}
142 	if (ret > 0) {
143 		/* An error occurred. */
144 		errno = ret;
145 		return (-1);
146 	}
147 	conn = proto_alloc(proto, side);
148 	if (conn == NULL) {
149 		if (proto->prt_close != NULL)
150 			proto->prt_close(ctx);
151 		errno = ENOMEM;
152 		return (-1);
153 	}
154 	conn->pc_ctx = ctx;
155 	*connp = conn;
156 
157 	return (0);
158 }
159 
160 int
161 proto_client(const char *srcaddr, const char *dstaddr,
162     struct proto_conn **connp)
163 {
164 
165 	return (proto_common_setup(srcaddr, dstaddr, connp, PROTO_SIDE_CLIENT));
166 }
167 
168 int
169 proto_connect(struct proto_conn *conn, int timeout)
170 {
171 	int ret;
172 
173 	PJDLOG_ASSERT(conn != NULL);
174 	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
175 	PJDLOG_ASSERT(conn->pc_side == PROTO_SIDE_CLIENT);
176 	PJDLOG_ASSERT(conn->pc_proto != NULL);
177 	PJDLOG_ASSERT(conn->pc_proto->prt_connect != NULL);
178 	PJDLOG_ASSERT(timeout >= -1);
179 
180 	ret = conn->pc_proto->prt_connect(conn->pc_ctx, timeout);
181 	if (ret != 0) {
182 		errno = ret;
183 		return (-1);
184 	}
185 
186 	return (0);
187 }
188 
189 int
190 proto_connect_wait(struct proto_conn *conn, int timeout)
191 {
192 	int ret;
193 
194 	PJDLOG_ASSERT(conn != NULL);
195 	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
196 	PJDLOG_ASSERT(conn->pc_side == PROTO_SIDE_CLIENT);
197 	PJDLOG_ASSERT(conn->pc_proto != NULL);
198 	PJDLOG_ASSERT(conn->pc_proto->prt_connect_wait != NULL);
199 	PJDLOG_ASSERT(timeout >= 0);
200 
201 	ret = conn->pc_proto->prt_connect_wait(conn->pc_ctx, timeout);
202 	if (ret != 0) {
203 		errno = ret;
204 		return (-1);
205 	}
206 
207 	return (0);
208 }
209 
210 int
211 proto_server(const char *addr, struct proto_conn **connp)
212 {
213 
214 	return (proto_common_setup(NULL, addr, connp, PROTO_SIDE_SERVER_LISTEN));
215 }
216 
217 int
218 proto_accept(struct proto_conn *conn, struct proto_conn **newconnp)
219 {
220 	struct proto_conn *newconn;
221 	int ret;
222 
223 	PJDLOG_ASSERT(conn != NULL);
224 	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
225 	PJDLOG_ASSERT(conn->pc_side == PROTO_SIDE_SERVER_LISTEN);
226 	PJDLOG_ASSERT(conn->pc_proto != NULL);
227 	PJDLOG_ASSERT(conn->pc_proto->prt_accept != NULL);
228 
229 	newconn = proto_alloc(conn->pc_proto, PROTO_SIDE_SERVER_WORK);
230 	if (newconn == NULL)
231 		return (-1);
232 
233 	ret = conn->pc_proto->prt_accept(conn->pc_ctx, &newconn->pc_ctx);
234 	if (ret != 0) {
235 		proto_free(newconn);
236 		errno = ret;
237 		return (-1);
238 	}
239 
240 	*newconnp = newconn;
241 
242 	return (0);
243 }
244 
245 int
246 proto_send(const struct proto_conn *conn, const void *data, size_t size)
247 {
248 	int ret;
249 
250 	PJDLOG_ASSERT(conn != NULL);
251 	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
252 	PJDLOG_ASSERT(conn->pc_proto != NULL);
253 	PJDLOG_ASSERT(conn->pc_proto->prt_send != NULL);
254 
255 	ret = conn->pc_proto->prt_send(conn->pc_ctx, data, size, -1);
256 	if (ret != 0) {
257 		errno = ret;
258 		return (-1);
259 	}
260 	return (0);
261 }
262 
263 int
264 proto_recv(const struct proto_conn *conn, void *data, size_t size)
265 {
266 	int ret;
267 
268 	PJDLOG_ASSERT(conn != NULL);
269 	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
270 	PJDLOG_ASSERT(conn->pc_proto != NULL);
271 	PJDLOG_ASSERT(conn->pc_proto->prt_recv != NULL);
272 
273 	ret = conn->pc_proto->prt_recv(conn->pc_ctx, data, size, NULL);
274 	if (ret != 0) {
275 		errno = ret;
276 		return (-1);
277 	}
278 	return (0);
279 }
280 
281 int
282 proto_connection_send(const struct proto_conn *conn, struct proto_conn *mconn)
283 {
284 	const char *protoname;
285 	int ret, fd;
286 
287 	PJDLOG_ASSERT(conn != NULL);
288 	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
289 	PJDLOG_ASSERT(conn->pc_proto != NULL);
290 	PJDLOG_ASSERT(conn->pc_proto->prt_send != NULL);
291 	PJDLOG_ASSERT(mconn != NULL);
292 	PJDLOG_ASSERT(mconn->pc_magic == PROTO_CONN_MAGIC);
293 	PJDLOG_ASSERT(mconn->pc_proto != NULL);
294 	fd = proto_descriptor(mconn);
295 	PJDLOG_ASSERT(fd >= 0);
296 	protoname = mconn->pc_proto->prt_name;
297 	PJDLOG_ASSERT(protoname != NULL);
298 
299 	ret = conn->pc_proto->prt_send(conn->pc_ctx,
300 	    (const unsigned char *)protoname, strlen(protoname) + 1, fd);
301 	proto_close(mconn);
302 	if (ret != 0) {
303 		errno = ret;
304 		return (-1);
305 	}
306 	return (0);
307 }
308 
309 int
310 proto_connection_recv(const struct proto_conn *conn, bool client,
311     struct proto_conn **newconnp)
312 {
313 	char protoname[128];
314 	struct proto *proto;
315 	struct proto_conn *newconn;
316 	int ret, fd;
317 
318 	PJDLOG_ASSERT(conn != NULL);
319 	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
320 	PJDLOG_ASSERT(conn->pc_proto != NULL);
321 	PJDLOG_ASSERT(conn->pc_proto->prt_recv != NULL);
322 	PJDLOG_ASSERT(newconnp != NULL);
323 
324 	bzero(protoname, sizeof(protoname));
325 
326 	ret = conn->pc_proto->prt_recv(conn->pc_ctx, (unsigned char *)protoname,
327 	    sizeof(protoname) - 1, &fd);
328 	if (ret != 0) {
329 		errno = ret;
330 		return (-1);
331 	}
332 
333 	PJDLOG_ASSERT(fd >= 0);
334 
335 	TAILQ_FOREACH(proto, &protos, prt_next) {
336 		if (strcmp(proto->prt_name, protoname) == 0)
337 			break;
338 	}
339 	if (proto == NULL) {
340 		errno = EINVAL;
341 		return (-1);
342 	}
343 
344 	newconn = proto_alloc(proto,
345 	    client ? PROTO_SIDE_CLIENT : PROTO_SIDE_SERVER_WORK);
346 	if (newconn == NULL)
347 		return (-1);
348 	PJDLOG_ASSERT(newconn->pc_proto->prt_wrap != NULL);
349 	ret = newconn->pc_proto->prt_wrap(fd, client, &newconn->pc_ctx);
350 	if (ret != 0) {
351 		proto_free(newconn);
352 		errno = ret;
353 		return (-1);
354 	}
355 
356 	*newconnp = newconn;
357 
358 	return (0);
359 }
360 
361 int
362 proto_descriptor(const struct proto_conn *conn)
363 {
364 
365 	PJDLOG_ASSERT(conn != NULL);
366 	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
367 	PJDLOG_ASSERT(conn->pc_proto != NULL);
368 	PJDLOG_ASSERT(conn->pc_proto->prt_descriptor != NULL);
369 
370 	return (conn->pc_proto->prt_descriptor(conn->pc_ctx));
371 }
372 
373 bool
374 proto_address_match(const struct proto_conn *conn, const char *addr)
375 {
376 
377 	PJDLOG_ASSERT(conn != NULL);
378 	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
379 	PJDLOG_ASSERT(conn->pc_proto != NULL);
380 	PJDLOG_ASSERT(conn->pc_proto->prt_address_match != NULL);
381 
382 	return (conn->pc_proto->prt_address_match(conn->pc_ctx, addr));
383 }
384 
385 void
386 proto_local_address(const struct proto_conn *conn, char *addr, size_t size)
387 {
388 
389 	PJDLOG_ASSERT(conn != NULL);
390 	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
391 	PJDLOG_ASSERT(conn->pc_proto != NULL);
392 	PJDLOG_ASSERT(conn->pc_proto->prt_local_address != NULL);
393 
394 	conn->pc_proto->prt_local_address(conn->pc_ctx, addr, size);
395 }
396 
397 void
398 proto_remote_address(const struct proto_conn *conn, char *addr, size_t size)
399 {
400 
401 	PJDLOG_ASSERT(conn != NULL);
402 	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
403 	PJDLOG_ASSERT(conn->pc_proto != NULL);
404 	PJDLOG_ASSERT(conn->pc_proto->prt_remote_address != NULL);
405 
406 	conn->pc_proto->prt_remote_address(conn->pc_ctx, addr, size);
407 }
408 
409 int
410 proto_timeout(const struct proto_conn *conn, int timeout)
411 {
412 	struct timeval tv;
413 	int fd;
414 
415 	PJDLOG_ASSERT(conn != NULL);
416 	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
417 	PJDLOG_ASSERT(conn->pc_proto != NULL);
418 
419 	fd = proto_descriptor(conn);
420 	if (fd == -1)
421 		return (-1);
422 
423 	tv.tv_sec = timeout;
424 	tv.tv_usec = 0;
425 	if (setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)) == -1)
426 		return (-1);
427 	if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)) == -1)
428 		return (-1);
429 
430 	return (0);
431 }
432 
433 void
434 proto_close(struct proto_conn *conn)
435 {
436 
437 	PJDLOG_ASSERT(conn != NULL);
438 	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
439 	PJDLOG_ASSERT(conn->pc_proto != NULL);
440 	PJDLOG_ASSERT(conn->pc_proto->prt_close != NULL);
441 
442 	conn->pc_proto->prt_close(conn->pc_ctx);
443 	proto_free(conn);
444 }
445