xref: /freebsd/sbin/hastd/proto_socketpair.c (revision 8d20be1e22095c27faf8fe8b2f0d089739cc742e)
1 /*-
2  * Copyright (c) 2009-2010 The FreeBSD Foundation
3  * All rights reserved.
4  *
5  * This software was developed by Pawel Jakub Dawidek under sponsorship from
6  * the FreeBSD Foundation.
7  *
8  * Redistribution and use in source and binary forms, with or without
9  * modification, are permitted provided that the following conditions
10  * are met:
11  * 1. Redistributions of source code must retain the above copyright
12  *    notice, this list of conditions and the following disclaimer.
13  * 2. Redistributions in binary form must reproduce the above copyright
14  *    notice, this list of conditions and the following disclaimer in the
15  *    documentation and/or other materials provided with the distribution.
16  *
17  * THIS SOFTWARE IS PROVIDED BY THE AUTHORS AND CONTRIBUTORS ``AS IS'' AND
18  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20  * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE LIABLE
21  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
22  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
23  * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
24  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
25  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
26  * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
27  * SUCH DAMAGE.
28  */
29 
30 #include <sys/cdefs.h>
31 __FBSDID("$FreeBSD$");
32 
33 #include <sys/types.h>
34 #include <sys/socket.h>
35 
#include <errno.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
42 
43 #include "pjdlog.h"
44 #include "proto_impl.h"
45 
46 #define	SP_CTX_MAGIC	0x50c3741
47 struct sp_ctx {
48 	int			sp_magic;
49 	int			sp_fd[2];
50 	int			sp_side;
51 #define	SP_SIDE_UNDEF		0
52 #define	SP_SIDE_CLIENT		1
53 #define	SP_SIDE_SERVER		2
54 };
55 
56 static void sp_close(void *ctx);
57 
58 static int
59 sp_client(const char *srcaddr, const char *dstaddr, void **ctxp)
60 {
61 	struct sp_ctx *spctx;
62 	int ret;
63 
64 	if (strcmp(dstaddr, "socketpair://") != 0)
65 		return (-1);
66 
67 	PJDLOG_ASSERT(srcaddr == NULL);
68 
69 	spctx = malloc(sizeof(*spctx));
70 	if (spctx == NULL)
71 		return (errno);
72 
73 	if (socketpair(PF_UNIX, SOCK_STREAM, 0, spctx->sp_fd) == -1) {
74 		ret = errno;
75 		free(spctx);
76 		return (ret);
77 	}
78 
79 	spctx->sp_side = SP_SIDE_UNDEF;
80 	spctx->sp_magic = SP_CTX_MAGIC;
81 	*ctxp = spctx;
82 
83 	return (0);
84 }
85 
86 static int
87 sp_send(void *ctx, const unsigned char *data, size_t size, int fd)
88 {
89 	struct sp_ctx *spctx = ctx;
90 	int sock;
91 
92 	PJDLOG_ASSERT(spctx != NULL);
93 	PJDLOG_ASSERT(spctx->sp_magic == SP_CTX_MAGIC);
94 
95 	switch (spctx->sp_side) {
96 	case SP_SIDE_UNDEF:
97 		/*
98 		 * If the first operation done by the caller is proto_send(),
99 		 * we assume this is the client.
100 		 */
101 		/* FALLTHROUGH */
102 		spctx->sp_side = SP_SIDE_CLIENT;
103 		/* Close other end. */
104 		close(spctx->sp_fd[1]);
105 		spctx->sp_fd[1] = -1;
106 	case SP_SIDE_CLIENT:
107 		PJDLOG_ASSERT(spctx->sp_fd[0] >= 0);
108 		sock = spctx->sp_fd[0];
109 		break;
110 	case SP_SIDE_SERVER:
111 		PJDLOG_ASSERT(spctx->sp_fd[1] >= 0);
112 		sock = spctx->sp_fd[1];
113 		break;
114 	default:
115 		PJDLOG_ABORT("Invalid socket side (%d).", spctx->sp_side);
116 	}
117 
118 	/* Someone is just trying to decide about side. */
119 	if (data == NULL)
120 		return (0);
121 
122 	return (proto_common_send(sock, data, size, fd));
123 }
124 
125 static int
126 sp_recv(void *ctx, unsigned char *data, size_t size, int *fdp)
127 {
128 	struct sp_ctx *spctx = ctx;
129 	int fd;
130 
131 	PJDLOG_ASSERT(spctx != NULL);
132 	PJDLOG_ASSERT(spctx->sp_magic == SP_CTX_MAGIC);
133 
134 	switch (spctx->sp_side) {
135 	case SP_SIDE_UNDEF:
136 		/*
137 		 * If the first operation done by the caller is proto_recv(),
138 		 * we assume this is the server.
139 		 */
140 		/* FALLTHROUGH */
141 		spctx->sp_side = SP_SIDE_SERVER;
142 		/* Close other end. */
143 		close(spctx->sp_fd[0]);
144 		spctx->sp_fd[0] = -1;
145 	case SP_SIDE_SERVER:
146 		PJDLOG_ASSERT(spctx->sp_fd[1] >= 0);
147 		fd = spctx->sp_fd[1];
148 		break;
149 	case SP_SIDE_CLIENT:
150 		PJDLOG_ASSERT(spctx->sp_fd[0] >= 0);
151 		fd = spctx->sp_fd[0];
152 		break;
153 	default:
154 		PJDLOG_ABORT("Invalid socket side (%d).", spctx->sp_side);
155 	}
156 
157 	/* Someone is just trying to decide about side. */
158 	if (data == NULL)
159 		return (0);
160 
161 	return (proto_common_recv(fd, data, size, fdp));
162 }
163 
164 static int
165 sp_descriptor(const void *ctx)
166 {
167 	const struct sp_ctx *spctx = ctx;
168 
169 	PJDLOG_ASSERT(spctx != NULL);
170 	PJDLOG_ASSERT(spctx->sp_magic == SP_CTX_MAGIC);
171 	PJDLOG_ASSERT(spctx->sp_side == SP_SIDE_CLIENT ||
172 	    spctx->sp_side == SP_SIDE_SERVER);
173 
174 	switch (spctx->sp_side) {
175 	case SP_SIDE_CLIENT:
176 		PJDLOG_ASSERT(spctx->sp_fd[0] >= 0);
177 		return (spctx->sp_fd[0]);
178 	case SP_SIDE_SERVER:
179 		PJDLOG_ASSERT(spctx->sp_fd[1] >= 0);
180 		return (spctx->sp_fd[1]);
181 	}
182 
183 	PJDLOG_ABORT("Invalid socket side (%d).", spctx->sp_side);
184 }
185 
186 static void
187 sp_close(void *ctx)
188 {
189 	struct sp_ctx *spctx = ctx;
190 
191 	PJDLOG_ASSERT(spctx != NULL);
192 	PJDLOG_ASSERT(spctx->sp_magic == SP_CTX_MAGIC);
193 
194 	switch (spctx->sp_side) {
195 	case SP_SIDE_UNDEF:
196 		PJDLOG_ASSERT(spctx->sp_fd[0] >= 0);
197 		close(spctx->sp_fd[0]);
198 		spctx->sp_fd[0] = -1;
199 		PJDLOG_ASSERT(spctx->sp_fd[1] >= 0);
200 		close(spctx->sp_fd[1]);
201 		spctx->sp_fd[1] = -1;
202 		break;
203 	case SP_SIDE_CLIENT:
204 		PJDLOG_ASSERT(spctx->sp_fd[0] >= 0);
205 		close(spctx->sp_fd[0]);
206 		spctx->sp_fd[0] = -1;
207 		PJDLOG_ASSERT(spctx->sp_fd[1] == -1);
208 		break;
209 	case SP_SIDE_SERVER:
210 		PJDLOG_ASSERT(spctx->sp_fd[1] >= 0);
211 		close(spctx->sp_fd[1]);
212 		spctx->sp_fd[1] = -1;
213 		PJDLOG_ASSERT(spctx->sp_fd[0] == -1);
214 		break;
215 	default:
216 		PJDLOG_ABORT("Invalid socket side (%d).", spctx->sp_side);
217 	}
218 
219 	spctx->sp_magic = 0;
220 	free(spctx);
221 }
222 
223 static struct proto sp_proto = {
224 	.prt_name = "socketpair",
225 	.prt_client = sp_client,
226 	.prt_send = sp_send,
227 	.prt_recv = sp_recv,
228 	.prt_descriptor = sp_descriptor,
229 	.prt_close = sp_close
230 };
231 
232 static __constructor void
233 sp_ctor(void)
234 {
235 
236 	proto_register(&sp_proto, false);
237 }
238