xref: /freebsd/sbin/hastd/proto.c (revision 3823d5e198425b4f5e5a80267d195769d1063773)
1 /*-
2  * Copyright (c) 2009-2010 The FreeBSD Foundation
3  * All rights reserved.
4  *
5  * This software was developed by Pawel Jakub Dawidek under sponsorship from
6  * the FreeBSD Foundation.
7  *
8  * Redistribution and use in source and binary forms, with or without
9  * modification, are permitted provided that the following conditions
10  * are met:
11  * 1. Redistributions of source code must retain the above copyright
12  *    notice, this list of conditions and the following disclaimer.
13  * 2. Redistributions in binary form must reproduce the above copyright
14  *    notice, this list of conditions and the following disclaimer in the
15  *    documentation and/or other materials provided with the distribution.
16  *
17  * THIS SOFTWARE IS PROVIDED BY THE AUTHORS AND CONTRIBUTORS ``AS IS'' AND
18  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20  * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE LIABLE
21  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
22  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
23  * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
24  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
25  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
26  * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
27  * SUCH DAMAGE.
28  */
29 
30 #include <sys/cdefs.h>
31 __FBSDID("$FreeBSD$");
32 
#include <sys/types.h>
#include <sys/queue.h>
#include <sys/socket.h>

#include <errno.h>
#include <stdint.h>
#include <string.h>
#include <strings.h>
#include <unistd.h>
41 
42 #include "pjdlog.h"
43 #include "proto.h"
44 #include "proto_impl.h"
45 
/* Sanity marker stored in every live proto_conn and checked by assertions. */
#define	PROTO_CONN_MAGIC	0x907041c

/* Values stored in proto_conn.pc_side. */
#define	PROTO_SIDE_CLIENT		0
#define	PROTO_SIDE_SERVER_LISTEN	1
#define	PROTO_SIDE_SERVER_WORK		2

/*
 * One connection handle: ties a protocol implementation to the opaque
 * context that implementation's callbacks operate on.
 */
struct proto_conn {
	int		 pc_magic;	/* PROTO_CONN_MAGIC while valid */
	struct proto	*pc_proto;	/* protocol whose callbacks are used */
	void		*pc_ctx;	/* context passed to pc_proto callbacks */
	int		 pc_side;	/* one of PROTO_SIDE_* */
};

/* Registered protocols, walked in order when resolving an address or name. */
static TAILQ_HEAD(, proto) protos = TAILQ_HEAD_INITIALIZER(protos);
58 
59 void
60 proto_register(struct proto *proto, bool isdefault)
61 {
62 	static bool seen_default = false;
63 
64 	if (!isdefault)
65 		TAILQ_INSERT_HEAD(&protos, proto, prt_next);
66 	else {
67 		PJDLOG_ASSERT(!seen_default);
68 		seen_default = true;
69 		TAILQ_INSERT_TAIL(&protos, proto, prt_next);
70 	}
71 }
72 
73 static struct proto_conn *
74 proto_alloc(struct proto *proto, int side)
75 {
76 	struct proto_conn *conn;
77 
78 	PJDLOG_ASSERT(proto != NULL);
79 	PJDLOG_ASSERT(side == PROTO_SIDE_CLIENT ||
80 	    side == PROTO_SIDE_SERVER_LISTEN ||
81 	    side == PROTO_SIDE_SERVER_WORK);
82 
83 	conn = malloc(sizeof(*conn));
84 	if (conn != NULL) {
85 		conn->pc_proto = proto;
86 		conn->pc_side = side;
87 		conn->pc_magic = PROTO_CONN_MAGIC;
88 	}
89 	return (conn);
90 }
91 
92 static void
93 proto_free(struct proto_conn *conn)
94 {
95 
96 	PJDLOG_ASSERT(conn != NULL);
97 	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
98 	PJDLOG_ASSERT(conn->pc_side == PROTO_SIDE_CLIENT ||
99 	    conn->pc_side == PROTO_SIDE_SERVER_LISTEN ||
100 	    conn->pc_side == PROTO_SIDE_SERVER_WORK);
101 	PJDLOG_ASSERT(conn->pc_proto != NULL);
102 
103 	bzero(conn, sizeof(*conn));
104 	free(conn);
105 }
106 
107 static int
108 proto_common_setup(const char *srcaddr, const char *dstaddr,
109     struct proto_conn **connp, int side)
110 {
111 	struct proto *proto;
112 	struct proto_conn *conn;
113 	void *ctx;
114 	int ret;
115 
116 	PJDLOG_ASSERT(side == PROTO_SIDE_CLIENT ||
117 	    side == PROTO_SIDE_SERVER_LISTEN);
118 
119 	TAILQ_FOREACH(proto, &protos, prt_next) {
120 		if (side == PROTO_SIDE_CLIENT) {
121 			if (proto->prt_client == NULL)
122 				ret = -1;
123 			else
124 				ret = proto->prt_client(srcaddr, dstaddr, &ctx);
125 		} else /* if (side == PROTO_SIDE_SERVER_LISTEN) */ {
126 			if (proto->prt_server == NULL)
127 				ret = -1;
128 			else
129 				ret = proto->prt_server(dstaddr, &ctx);
130 		}
131 		/*
132 		 * ret == 0  - success
133 		 * ret == -1 - dstaddr is not for this protocol
134 		 * ret > 0   - right protocol, but an error occurred
135 		 */
136 		if (ret >= 0)
137 			break;
138 	}
139 	if (proto == NULL) {
140 		/* Unrecognized address. */
141 		errno = EINVAL;
142 		return (-1);
143 	}
144 	if (ret > 0) {
145 		/* An error occurred. */
146 		errno = ret;
147 		return (-1);
148 	}
149 	conn = proto_alloc(proto, side);
150 	if (conn == NULL) {
151 		if (proto->prt_close != NULL)
152 			proto->prt_close(ctx);
153 		errno = ENOMEM;
154 		return (-1);
155 	}
156 	conn->pc_ctx = ctx;
157 	*connp = conn;
158 
159 	return (0);
160 }
161 
162 int
163 proto_client(const char *srcaddr, const char *dstaddr,
164     struct proto_conn **connp)
165 {
166 
167 	return (proto_common_setup(srcaddr, dstaddr, connp, PROTO_SIDE_CLIENT));
168 }
169 
170 int
171 proto_connect(struct proto_conn *conn, int timeout)
172 {
173 	int ret;
174 
175 	PJDLOG_ASSERT(conn != NULL);
176 	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
177 	PJDLOG_ASSERT(conn->pc_side == PROTO_SIDE_CLIENT);
178 	PJDLOG_ASSERT(conn->pc_proto != NULL);
179 	PJDLOG_ASSERT(conn->pc_proto->prt_connect != NULL);
180 	PJDLOG_ASSERT(timeout >= -1);
181 
182 	ret = conn->pc_proto->prt_connect(conn->pc_ctx, timeout);
183 	if (ret != 0) {
184 		errno = ret;
185 		return (-1);
186 	}
187 
188 	return (0);
189 }
190 
191 int
192 proto_connect_wait(struct proto_conn *conn, int timeout)
193 {
194 	int ret;
195 
196 	PJDLOG_ASSERT(conn != NULL);
197 	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
198 	PJDLOG_ASSERT(conn->pc_side == PROTO_SIDE_CLIENT);
199 	PJDLOG_ASSERT(conn->pc_proto != NULL);
200 	PJDLOG_ASSERT(conn->pc_proto->prt_connect_wait != NULL);
201 	PJDLOG_ASSERT(timeout >= 0);
202 
203 	ret = conn->pc_proto->prt_connect_wait(conn->pc_ctx, timeout);
204 	if (ret != 0) {
205 		errno = ret;
206 		return (-1);
207 	}
208 
209 	return (0);
210 }
211 
212 int
213 proto_server(const char *addr, struct proto_conn **connp)
214 {
215 
216 	return (proto_common_setup(NULL, addr, connp, PROTO_SIDE_SERVER_LISTEN));
217 }
218 
219 int
220 proto_accept(struct proto_conn *conn, struct proto_conn **newconnp)
221 {
222 	struct proto_conn *newconn;
223 	int ret;
224 
225 	PJDLOG_ASSERT(conn != NULL);
226 	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
227 	PJDLOG_ASSERT(conn->pc_side == PROTO_SIDE_SERVER_LISTEN);
228 	PJDLOG_ASSERT(conn->pc_proto != NULL);
229 	PJDLOG_ASSERT(conn->pc_proto->prt_accept != NULL);
230 
231 	newconn = proto_alloc(conn->pc_proto, PROTO_SIDE_SERVER_WORK);
232 	if (newconn == NULL)
233 		return (-1);
234 
235 	ret = conn->pc_proto->prt_accept(conn->pc_ctx, &newconn->pc_ctx);
236 	if (ret != 0) {
237 		proto_free(newconn);
238 		errno = ret;
239 		return (-1);
240 	}
241 
242 	*newconnp = newconn;
243 
244 	return (0);
245 }
246 
247 int
248 proto_send(const struct proto_conn *conn, const void *data, size_t size)
249 {
250 	int ret;
251 
252 	PJDLOG_ASSERT(conn != NULL);
253 	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
254 	PJDLOG_ASSERT(conn->pc_proto != NULL);
255 	PJDLOG_ASSERT(conn->pc_proto->prt_send != NULL);
256 
257 	ret = conn->pc_proto->prt_send(conn->pc_ctx, data, size, -1);
258 	if (ret != 0) {
259 		errno = ret;
260 		return (-1);
261 	}
262 	return (0);
263 }
264 
265 int
266 proto_recv(const struct proto_conn *conn, void *data, size_t size)
267 {
268 	int ret;
269 
270 	PJDLOG_ASSERT(conn != NULL);
271 	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
272 	PJDLOG_ASSERT(conn->pc_proto != NULL);
273 	PJDLOG_ASSERT(conn->pc_proto->prt_recv != NULL);
274 
275 	ret = conn->pc_proto->prt_recv(conn->pc_ctx, data, size, NULL);
276 	if (ret != 0) {
277 		errno = ret;
278 		return (-1);
279 	}
280 	return (0);
281 }
282 
283 int
284 proto_connection_send(const struct proto_conn *conn, struct proto_conn *mconn)
285 {
286 	const char *protoname;
287 	int ret, fd;
288 
289 	PJDLOG_ASSERT(conn != NULL);
290 	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
291 	PJDLOG_ASSERT(conn->pc_proto != NULL);
292 	PJDLOG_ASSERT(conn->pc_proto->prt_send != NULL);
293 	PJDLOG_ASSERT(mconn != NULL);
294 	PJDLOG_ASSERT(mconn->pc_magic == PROTO_CONN_MAGIC);
295 	PJDLOG_ASSERT(mconn->pc_proto != NULL);
296 	fd = proto_descriptor(mconn);
297 	PJDLOG_ASSERT(fd >= 0);
298 	protoname = mconn->pc_proto->prt_name;
299 	PJDLOG_ASSERT(protoname != NULL);
300 
301 	ret = conn->pc_proto->prt_send(conn->pc_ctx,
302 	    (const unsigned char *)protoname, strlen(protoname) + 1, fd);
303 	proto_close(mconn);
304 	if (ret != 0) {
305 		errno = ret;
306 		return (-1);
307 	}
308 	return (0);
309 }
310 
311 int
312 proto_connection_recv(const struct proto_conn *conn, bool client,
313     struct proto_conn **newconnp)
314 {
315 	char protoname[128];
316 	struct proto *proto;
317 	struct proto_conn *newconn;
318 	int ret, fd;
319 
320 	PJDLOG_ASSERT(conn != NULL);
321 	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
322 	PJDLOG_ASSERT(conn->pc_proto != NULL);
323 	PJDLOG_ASSERT(conn->pc_proto->prt_recv != NULL);
324 	PJDLOG_ASSERT(newconnp != NULL);
325 
326 	bzero(protoname, sizeof(protoname));
327 
328 	ret = conn->pc_proto->prt_recv(conn->pc_ctx, (unsigned char *)protoname,
329 	    sizeof(protoname) - 1, &fd);
330 	if (ret != 0) {
331 		errno = ret;
332 		return (-1);
333 	}
334 
335 	PJDLOG_ASSERT(fd >= 0);
336 
337 	TAILQ_FOREACH(proto, &protos, prt_next) {
338 		if (strcmp(proto->prt_name, protoname) == 0)
339 			break;
340 	}
341 	if (proto == NULL) {
342 		errno = EINVAL;
343 		return (-1);
344 	}
345 
346 	newconn = proto_alloc(proto,
347 	    client ? PROTO_SIDE_CLIENT : PROTO_SIDE_SERVER_WORK);
348 	if (newconn == NULL)
349 		return (-1);
350 	PJDLOG_ASSERT(newconn->pc_proto->prt_wrap != NULL);
351 	ret = newconn->pc_proto->prt_wrap(fd, client, &newconn->pc_ctx);
352 	if (ret != 0) {
353 		proto_free(newconn);
354 		errno = ret;
355 		return (-1);
356 	}
357 
358 	*newconnp = newconn;
359 
360 	return (0);
361 }
362 
363 int
364 proto_descriptor(const struct proto_conn *conn)
365 {
366 
367 	PJDLOG_ASSERT(conn != NULL);
368 	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
369 	PJDLOG_ASSERT(conn->pc_proto != NULL);
370 	PJDLOG_ASSERT(conn->pc_proto->prt_descriptor != NULL);
371 
372 	return (conn->pc_proto->prt_descriptor(conn->pc_ctx));
373 }
374 
375 bool
376 proto_address_match(const struct proto_conn *conn, const char *addr)
377 {
378 
379 	PJDLOG_ASSERT(conn != NULL);
380 	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
381 	PJDLOG_ASSERT(conn->pc_proto != NULL);
382 	PJDLOG_ASSERT(conn->pc_proto->prt_address_match != NULL);
383 
384 	return (conn->pc_proto->prt_address_match(conn->pc_ctx, addr));
385 }
386 
387 void
388 proto_local_address(const struct proto_conn *conn, char *addr, size_t size)
389 {
390 
391 	PJDLOG_ASSERT(conn != NULL);
392 	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
393 	PJDLOG_ASSERT(conn->pc_proto != NULL);
394 	PJDLOG_ASSERT(conn->pc_proto->prt_local_address != NULL);
395 
396 	conn->pc_proto->prt_local_address(conn->pc_ctx, addr, size);
397 }
398 
399 void
400 proto_remote_address(const struct proto_conn *conn, char *addr, size_t size)
401 {
402 
403 	PJDLOG_ASSERT(conn != NULL);
404 	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
405 	PJDLOG_ASSERT(conn->pc_proto != NULL);
406 	PJDLOG_ASSERT(conn->pc_proto->prt_remote_address != NULL);
407 
408 	conn->pc_proto->prt_remote_address(conn->pc_ctx, addr, size);
409 }
410 
411 int
412 proto_timeout(const struct proto_conn *conn, int timeout)
413 {
414 	struct timeval tv;
415 	int fd;
416 
417 	PJDLOG_ASSERT(conn != NULL);
418 	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
419 	PJDLOG_ASSERT(conn->pc_proto != NULL);
420 
421 	fd = proto_descriptor(conn);
422 	if (fd == -1)
423 		return (-1);
424 
425 	tv.tv_sec = timeout;
426 	tv.tv_usec = 0;
427 	if (setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)) == -1)
428 		return (-1);
429 	if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)) == -1)
430 		return (-1);
431 
432 	return (0);
433 }
434 
435 void
436 proto_close(struct proto_conn *conn)
437 {
438 
439 	PJDLOG_ASSERT(conn != NULL);
440 	PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC);
441 	PJDLOG_ASSERT(conn->pc_proto != NULL);
442 	PJDLOG_ASSERT(conn->pc_proto->prt_close != NULL);
443 
444 	conn->pc_proto->prt_close(conn->pc_ctx);
445 	proto_free(conn);
446 }
447