xref: /freebsd/sbin/hastd/proto_uds.c (revision ad30f8e79bd1007cc2476e491bd21b4f5e389e0a)
1 /*-
2  * Copyright (c) 2009-2010 The FreeBSD Foundation
3  * All rights reserved.
4  *
5  * This software was developed by Pawel Jakub Dawidek under sponsorship from
6  * the FreeBSD Foundation.
7  *
8  * Redistribution and use in source and binary forms, with or without
9  * modification, are permitted provided that the following conditions
10  * are met:
11  * 1. Redistributions of source code must retain the above copyright
12  *    notice, this list of conditions and the following disclaimer.
13  * 2. Redistributions in binary form must reproduce the above copyright
14  *    notice, this list of conditions and the following disclaimer in the
15  *    documentation and/or other materials provided with the distribution.
16  *
17  * THIS SOFTWARE IS PROVIDED BY THE AUTHORS AND CONTRIBUTORS ``AS IS'' AND
18  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20  * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE LIABLE
21  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
22  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
23  * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
24  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
25  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
26  * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
27  * SUCH DAMAGE.
28  */
29 
30 #include <sys/cdefs.h>
31 __FBSDID("$FreeBSD$");
32 
33 /* UDS - UNIX Domain Socket */
34 
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/un.h>

#include <errno.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <strings.h>
#include <unistd.h>
44 
45 #include "hast.h"
46 #include "pjdlog.h"
47 #include "proto_impl.h"
48 
/* Stored in uc_magic of every live context; cleared again by uds_close(). */
enum {
	UDS_CTX_MAGIC = 0xd541c
};

/* Role of a context, kept in uc_side. */
enum {
	UDS_SIDE_CLIENT = 0,		/* created by uds_client() */
	UDS_SIDE_SERVER_LISTEN = 1,	/* listening socket from uds_server() */
	UDS_SIDE_SERVER_WORK = 2	/* connection returned by uds_accept() */
};

/* Per-socket state of the UDS protocol backend. */
struct uds_ctx {
	int			uc_magic;	/* UDS_CTX_MAGIC while valid */
	struct sockaddr_un	uc_sun;		/* parsed address, or peer's */
	int			uc_fd;		/* socket descriptor */
	int			uc_side;	/* one of UDS_SIDE_* */
	pid_t			uc_owner;	/* set by uds_server() after bind */
};

static void uds_close(void *ctx);
62 
63 static int
64 uds_addr(const char *addr, struct sockaddr_un *sunp)
65 {
66 
67 	if (addr == NULL)
68 		return (-1);
69 
70 	if (strncasecmp(addr, "uds://", 6) == 0)
71 		addr += 6;
72 	else if (strncasecmp(addr, "unix://", 7) == 0)
73 		addr += 7;
74 	else if (addr[0] == '/' &&	/* If it starts from /... */
75 	    strstr(addr, "://") == NULL)/* ...and there is no prefix... */
76 		;			/* ...we assume its us. */
77 	else
78 		return (-1);
79 
80 	sunp->sun_family = AF_UNIX;
81 	if (strlcpy(sunp->sun_path, addr, sizeof(sunp->sun_path)) >=
82 	    sizeof(sunp->sun_path)) {
83 		return (ENAMETOOLONG);
84 	}
85 	sunp->sun_len = SUN_LEN(sunp);
86 
87 	return (0);
88 }
89 
90 static int
91 uds_common_setup(const char *addr, void **ctxp, int side)
92 {
93 	struct uds_ctx *uctx;
94 	int ret;
95 
96 	uctx = malloc(sizeof(*uctx));
97 	if (uctx == NULL)
98 		return (errno);
99 
100 	/* Parse given address. */
101 	if ((ret = uds_addr(addr, &uctx->uc_sun)) != 0) {
102 		free(uctx);
103 		return (ret);
104 	}
105 
106 	uctx->uc_fd = socket(AF_UNIX, SOCK_STREAM, 0);
107 	if (uctx->uc_fd == -1) {
108 		ret = errno;
109 		free(uctx);
110 		return (ret);
111 	}
112 
113 	uctx->uc_side = side;
114 	uctx->uc_owner = 0;
115 	uctx->uc_magic = UDS_CTX_MAGIC;
116 	*ctxp = uctx;
117 
118 	return (0);
119 }
120 
121 static int
122 uds_client(const char *addr, void **ctxp)
123 {
124 
125 	return (uds_common_setup(addr, ctxp, UDS_SIDE_CLIENT));
126 }
127 
128 static int
129 uds_connect(void *ctx, int timeout)
130 {
131 	struct uds_ctx *uctx = ctx;
132 
133 	PJDLOG_ASSERT(uctx != NULL);
134 	PJDLOG_ASSERT(uctx->uc_magic == UDS_CTX_MAGIC);
135 	PJDLOG_ASSERT(uctx->uc_side == UDS_SIDE_CLIENT);
136 	PJDLOG_ASSERT(uctx->uc_fd >= 0);
137 	PJDLOG_ASSERT(timeout >= -1);
138 
139 	if (connect(uctx->uc_fd, (struct sockaddr *)&uctx->uc_sun,
140 	    sizeof(uctx->uc_sun)) < 0) {
141 		return (errno);
142 	}
143 
144 	return (0);
145 }
146 
147 static int
148 uds_connect_wait(void *ctx, int timeout)
149 {
150 	struct uds_ctx *uctx = ctx;
151 
152 	PJDLOG_ASSERT(uctx != NULL);
153 	PJDLOG_ASSERT(uctx->uc_magic == UDS_CTX_MAGIC);
154 	PJDLOG_ASSERT(uctx->uc_side == UDS_SIDE_CLIENT);
155 	PJDLOG_ASSERT(uctx->uc_fd >= 0);
156 	PJDLOG_ASSERT(timeout >= 0);
157 
158 	return (0);
159 }
160 
161 static int
162 uds_server(const char *addr, void **ctxp)
163 {
164 	struct uds_ctx *uctx;
165 	int ret;
166 
167 	ret = uds_common_setup(addr, ctxp, UDS_SIDE_SERVER_LISTEN);
168 	if (ret != 0)
169 		return (ret);
170 
171 	uctx = *ctxp;
172 
173 	(void)unlink(uctx->uc_sun.sun_path);
174 	if (bind(uctx->uc_fd, (struct sockaddr *)&uctx->uc_sun,
175 	    sizeof(uctx->uc_sun)) < 0) {
176 		ret = errno;
177 		uds_close(uctx);
178 		return (ret);
179 	}
180 	uctx->uc_owner = getpid();
181 	if (listen(uctx->uc_fd, 8) < 0) {
182 		ret = errno;
183 		uds_close(uctx);
184 		return (ret);
185 	}
186 
187 	return (0);
188 }
189 
190 static int
191 uds_accept(void *ctx, void **newctxp)
192 {
193 	struct uds_ctx *uctx = ctx;
194 	struct uds_ctx *newuctx;
195 	socklen_t fromlen;
196 	int ret;
197 
198 	PJDLOG_ASSERT(uctx != NULL);
199 	PJDLOG_ASSERT(uctx->uc_magic == UDS_CTX_MAGIC);
200 	PJDLOG_ASSERT(uctx->uc_side == UDS_SIDE_SERVER_LISTEN);
201 	PJDLOG_ASSERT(uctx->uc_fd >= 0);
202 
203 	newuctx = malloc(sizeof(*newuctx));
204 	if (newuctx == NULL)
205 		return (errno);
206 
207 	fromlen = sizeof(newuctx->uc_sun);
208 	newuctx->uc_fd = accept(uctx->uc_fd,
209 	    (struct sockaddr *)&newuctx->uc_sun, &fromlen);
210 	if (newuctx->uc_fd < 0) {
211 		ret = errno;
212 		free(newuctx);
213 		return (ret);
214 	}
215 
216 	newuctx->uc_side = UDS_SIDE_SERVER_WORK;
217 	newuctx->uc_magic = UDS_CTX_MAGIC;
218 	*newctxp = newuctx;
219 
220 	return (0);
221 }
222 
223 static int
224 uds_send(void *ctx, const unsigned char *data, size_t size, int fd)
225 {
226 	struct uds_ctx *uctx = ctx;
227 
228 	PJDLOG_ASSERT(uctx != NULL);
229 	PJDLOG_ASSERT(uctx->uc_magic == UDS_CTX_MAGIC);
230 	PJDLOG_ASSERT(uctx->uc_fd >= 0);
231 
232 	return (proto_common_send(uctx->uc_fd, data, size, fd));
233 }
234 
235 static int
236 uds_recv(void *ctx, unsigned char *data, size_t size, int *fdp)
237 {
238 	struct uds_ctx *uctx = ctx;
239 
240 	PJDLOG_ASSERT(uctx != NULL);
241 	PJDLOG_ASSERT(uctx->uc_magic == UDS_CTX_MAGIC);
242 	PJDLOG_ASSERT(uctx->uc_fd >= 0);
243 
244 	return (proto_common_recv(uctx->uc_fd, data, size, fdp));
245 }
246 
247 static int
248 uds_descriptor(const void *ctx)
249 {
250 	const struct uds_ctx *uctx = ctx;
251 
252 	PJDLOG_ASSERT(uctx != NULL);
253 	PJDLOG_ASSERT(uctx->uc_magic == UDS_CTX_MAGIC);
254 
255 	return (uctx->uc_fd);
256 }
257 
258 static void
259 uds_local_address(const void *ctx, char *addr, size_t size)
260 {
261 	const struct uds_ctx *uctx = ctx;
262 	struct sockaddr_un sun;
263 	socklen_t sunlen;
264 
265 	PJDLOG_ASSERT(uctx != NULL);
266 	PJDLOG_ASSERT(uctx->uc_magic == UDS_CTX_MAGIC);
267 	PJDLOG_ASSERT(addr != NULL);
268 
269 	sunlen = sizeof(sun);
270 	if (getsockname(uctx->uc_fd, (struct sockaddr *)&sun, &sunlen) < 0) {
271 		PJDLOG_VERIFY(strlcpy(addr, "N/A", size) < size);
272 		return;
273 	}
274 	PJDLOG_ASSERT(sun.sun_family == AF_UNIX);
275 	if (sun.sun_path[0] == '\0') {
276 		PJDLOG_VERIFY(strlcpy(addr, "N/A", size) < size);
277 		return;
278 	}
279 	PJDLOG_VERIFY(snprintf(addr, size, "uds://%s", sun.sun_path) < (ssize_t)size);
280 }
281 
282 static void
283 uds_remote_address(const void *ctx, char *addr, size_t size)
284 {
285 	const struct uds_ctx *uctx = ctx;
286 	struct sockaddr_un sun;
287 	socklen_t sunlen;
288 
289 	PJDLOG_ASSERT(uctx != NULL);
290 	PJDLOG_ASSERT(uctx->uc_magic == UDS_CTX_MAGIC);
291 	PJDLOG_ASSERT(addr != NULL);
292 
293 	sunlen = sizeof(sun);
294 	if (getpeername(uctx->uc_fd, (struct sockaddr *)&sun, &sunlen) < 0) {
295 		PJDLOG_VERIFY(strlcpy(addr, "N/A", size) < size);
296 		return;
297 	}
298 	PJDLOG_ASSERT(sun.sun_family == AF_UNIX);
299 	if (sun.sun_path[0] == '\0') {
300 		PJDLOG_VERIFY(strlcpy(addr, "N/A", size) < size);
301 		return;
302 	}
303 	snprintf(addr, size, "uds://%s", sun.sun_path);
304 }
305 
306 static void
307 uds_close(void *ctx)
308 {
309 	struct uds_ctx *uctx = ctx;
310 
311 	PJDLOG_ASSERT(uctx != NULL);
312 	PJDLOG_ASSERT(uctx->uc_magic == UDS_CTX_MAGIC);
313 
314 	if (uctx->uc_fd >= 0)
315 		close(uctx->uc_fd);
316 	/*
317 	 * Unlink the socket only if we are the owner and this is descriptor
318 	 * we listen on.
319 	 */
320 	if (uctx->uc_side == UDS_SIDE_SERVER_LISTEN &&
321 	    uctx->uc_owner == getpid()) {
322 		PJDLOG_ASSERT(uctx->uc_sun.sun_path[0] != '\0');
323 		if (unlink(uctx->uc_sun.sun_path) == -1) {
324 			pjdlog_errno(LOG_WARNING,
325 			    "Unable to unlink socket file %s",
326 			    uctx->uc_sun.sun_path);
327 		}
328 	}
329 	uctx->uc_owner = 0;
330 	uctx->uc_magic = 0;
331 	free(uctx);
332 }
333 
334 static struct hast_proto uds_proto = {
335 	.hp_name = "uds",
336 	.hp_client = uds_client,
337 	.hp_connect = uds_connect,
338 	.hp_connect_wait = uds_connect_wait,
339 	.hp_server = uds_server,
340 	.hp_accept = uds_accept,
341 	.hp_send = uds_send,
342 	.hp_recv = uds_recv,
343 	.hp_descriptor = uds_descriptor,
344 	.hp_local_address = uds_local_address,
345 	.hp_remote_address = uds_remote_address,
346 	.hp_close = uds_close
347 };
348 
349 static __constructor void
350 uds_ctor(void)
351 {
352 
353 	proto_register(&uds_proto, false);
354 }
355