xref: /freebsd/sbin/hastd/proto_socketpair.c (revision ad30f8e79bd1007cc2476e491bd21b4f5e389e0a)
1 /*-
2  * Copyright (c) 2009-2010 The FreeBSD Foundation
3  * All rights reserved.
4  *
5  * This software was developed by Pawel Jakub Dawidek under sponsorship from
6  * the FreeBSD Foundation.
7  *
8  * Redistribution and use in source and binary forms, with or without
9  * modification, are permitted provided that the following conditions
10  * are met:
11  * 1. Redistributions of source code must retain the above copyright
12  *    notice, this list of conditions and the following disclaimer.
13  * 2. Redistributions in binary form must reproduce the above copyright
14  *    notice, this list of conditions and the following disclaimer in the
15  *    documentation and/or other materials provided with the distribution.
16  *
17  * THIS SOFTWARE IS PROVIDED BY THE AUTHORS AND CONTRIBUTORS ``AS IS'' AND
18  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20  * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE LIABLE
21  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
22  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
23  * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
24  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
25  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
26  * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
27  * SUCH DAMAGE.
28  */
29 
30 #include <sys/cdefs.h>
31 __FBSDID("$FreeBSD$");
32 
#include <sys/types.h>
#include <sys/socket.h>

#include <errno.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

#include "hast.h"
#include "pjdlog.h"
#include "proto_impl.h"
46 
/* Stored in sp_magic while a context is valid; sp_close() resets it to 0. */
#define	SP_CTX_MAGIC	0x50c3741

/* Values of sp_side: which end of the pair this process has kept. */
#define	SP_SIDE_UNDEF		0	/* not chosen yet, both ends open */
#define	SP_SIDE_CLIENT		1	/* keeps sp_fd[0] */
#define	SP_SIDE_SERVER		2	/* keeps sp_fd[1] */

/*
 * State for one "socketpair://" connection.  sp_client() fills in both
 * descriptors; the first sp_send() or sp_recv() picks a side and closes
 * the end that side does not use, setting that slot to -1.
 */
struct sp_ctx {
	int	sp_magic;	/* SP_CTX_MAGIC while valid */
	int	sp_fd[2];	/* [0]: client end, [1]: server end */
	int	sp_side;	/* one of SP_SIDE_* */
};

static void sp_close(void *ctx);
58 
59 static int
60 sp_client(const char *addr, void **ctxp)
61 {
62 	struct sp_ctx *spctx;
63 	int ret;
64 
65 	if (strcmp(addr, "socketpair://") != 0)
66 		return (-1);
67 
68 	spctx = malloc(sizeof(*spctx));
69 	if (spctx == NULL)
70 		return (errno);
71 
72 	if (socketpair(PF_UNIX, SOCK_STREAM, 0, spctx->sp_fd) < 0) {
73 		ret = errno;
74 		free(spctx);
75 		return (ret);
76 	}
77 
78 	spctx->sp_side = SP_SIDE_UNDEF;
79 	spctx->sp_magic = SP_CTX_MAGIC;
80 	*ctxp = spctx;
81 
82 	return (0);
83 }
84 
85 static int
86 sp_send(void *ctx, const unsigned char *data, size_t size, int fd)
87 {
88 	struct sp_ctx *spctx = ctx;
89 	int sock;
90 
91 	PJDLOG_ASSERT(spctx != NULL);
92 	PJDLOG_ASSERT(spctx->sp_magic == SP_CTX_MAGIC);
93 
94 	switch (spctx->sp_side) {
95 	case SP_SIDE_UNDEF:
96 		/*
97 		 * If the first operation done by the caller is proto_send(),
98 		 * we assume this is the client.
99 		 */
100 		/* FALLTHROUGH */
101 		spctx->sp_side = SP_SIDE_CLIENT;
102 		/* Close other end. */
103 		close(spctx->sp_fd[1]);
104 		spctx->sp_fd[1] = -1;
105 	case SP_SIDE_CLIENT:
106 		PJDLOG_ASSERT(spctx->sp_fd[0] >= 0);
107 		sock = spctx->sp_fd[0];
108 		break;
109 	case SP_SIDE_SERVER:
110 		PJDLOG_ASSERT(spctx->sp_fd[1] >= 0);
111 		sock = spctx->sp_fd[1];
112 		break;
113 	default:
114 		PJDLOG_ABORT("Invalid socket side (%d).", spctx->sp_side);
115 	}
116 
117 	/* Someone is just trying to decide about side. */
118 	if (data == NULL)
119 		return (0);
120 
121 	return (proto_common_send(sock, data, size, fd));
122 }
123 
124 static int
125 sp_recv(void *ctx, unsigned char *data, size_t size, int *fdp)
126 {
127 	struct sp_ctx *spctx = ctx;
128 	int fd;
129 
130 	PJDLOG_ASSERT(spctx != NULL);
131 	PJDLOG_ASSERT(spctx->sp_magic == SP_CTX_MAGIC);
132 
133 	switch (spctx->sp_side) {
134 	case SP_SIDE_UNDEF:
135 		/*
136 		 * If the first operation done by the caller is proto_recv(),
137 		 * we assume this is the server.
138 		 */
139 		/* FALLTHROUGH */
140 		spctx->sp_side = SP_SIDE_SERVER;
141 		/* Close other end. */
142 		close(spctx->sp_fd[0]);
143 		spctx->sp_fd[0] = -1;
144 	case SP_SIDE_SERVER:
145 		PJDLOG_ASSERT(spctx->sp_fd[1] >= 0);
146 		fd = spctx->sp_fd[1];
147 		break;
148 	case SP_SIDE_CLIENT:
149 		PJDLOG_ASSERT(spctx->sp_fd[0] >= 0);
150 		fd = spctx->sp_fd[0];
151 		break;
152 	default:
153 		PJDLOG_ABORT("Invalid socket side (%d).", spctx->sp_side);
154 	}
155 
156 	/* Someone is just trying to decide about side. */
157 	if (data == NULL)
158 		return (0);
159 
160 	return (proto_common_recv(fd, data, size, fdp));
161 }
162 
163 static int
164 sp_descriptor(const void *ctx)
165 {
166 	const struct sp_ctx *spctx = ctx;
167 
168 	PJDLOG_ASSERT(spctx != NULL);
169 	PJDLOG_ASSERT(spctx->sp_magic == SP_CTX_MAGIC);
170 	PJDLOG_ASSERT(spctx->sp_side == SP_SIDE_CLIENT ||
171 	    spctx->sp_side == SP_SIDE_SERVER);
172 
173 	switch (spctx->sp_side) {
174 	case SP_SIDE_CLIENT:
175 		PJDLOG_ASSERT(spctx->sp_fd[0] >= 0);
176 		return (spctx->sp_fd[0]);
177 	case SP_SIDE_SERVER:
178 		PJDLOG_ASSERT(spctx->sp_fd[1] >= 0);
179 		return (spctx->sp_fd[1]);
180 	}
181 
182 	PJDLOG_ABORT("Invalid socket side (%d).", spctx->sp_side);
183 }
184 
185 static void
186 sp_close(void *ctx)
187 {
188 	struct sp_ctx *spctx = ctx;
189 
190 	PJDLOG_ASSERT(spctx != NULL);
191 	PJDLOG_ASSERT(spctx->sp_magic == SP_CTX_MAGIC);
192 
193 	switch (spctx->sp_side) {
194 	case SP_SIDE_UNDEF:
195 		PJDLOG_ASSERT(spctx->sp_fd[0] >= 0);
196 		close(spctx->sp_fd[0]);
197 		spctx->sp_fd[0] = -1;
198 		PJDLOG_ASSERT(spctx->sp_fd[1] >= 0);
199 		close(spctx->sp_fd[1]);
200 		spctx->sp_fd[1] = -1;
201 		break;
202 	case SP_SIDE_CLIENT:
203 		PJDLOG_ASSERT(spctx->sp_fd[0] >= 0);
204 		close(spctx->sp_fd[0]);
205 		spctx->sp_fd[0] = -1;
206 		PJDLOG_ASSERT(spctx->sp_fd[1] == -1);
207 		break;
208 	case SP_SIDE_SERVER:
209 		PJDLOG_ASSERT(spctx->sp_fd[1] >= 0);
210 		close(spctx->sp_fd[1]);
211 		spctx->sp_fd[1] = -1;
212 		PJDLOG_ASSERT(spctx->sp_fd[0] == -1);
213 		break;
214 	default:
215 		PJDLOG_ABORT("Invalid socket side (%d).", spctx->sp_side);
216 	}
217 
218 	spctx->sp_magic = 0;
219 	free(spctx);
220 }
221 
222 static struct hast_proto sp_proto = {
223 	.hp_name = "socketpair",
224 	.hp_client = sp_client,
225 	.hp_send = sp_send,
226 	.hp_recv = sp_recv,
227 	.hp_descriptor = sp_descriptor,
228 	.hp_close = sp_close
229 };
230 
231 static __constructor void
232 sp_ctor(void)
233 {
234 
235 	proto_register(&sp_proto, false);
236 }
237