xref: /freebsd/sbin/hastd/proto_socketpair.c (revision 43e29d03f416d7dda52112a29600a7c82ee1a91e)
1 /*-
2  * SPDX-License-Identifier: BSD-2-Clause
3  *
4  * Copyright (c) 2009-2010 The FreeBSD Foundation
5  * All rights reserved.
6  *
7  * This software was developed by Pawel Jakub Dawidek under sponsorship from
8  * the FreeBSD Foundation.
9  *
10  * Redistribution and use in source and binary forms, with or without
11  * modification, are permitted provided that the following conditions
12  * are met:
13  * 1. Redistributions of source code must retain the above copyright
14  *    notice, this list of conditions and the following disclaimer.
15  * 2. Redistributions in binary form must reproduce the above copyright
16  *    notice, this list of conditions and the following disclaimer in the
17  *    documentation and/or other materials provided with the distribution.
18  *
19  * THIS SOFTWARE IS PROVIDED BY THE AUTHORS AND CONTRIBUTORS ``AS IS'' AND
20  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
22  * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE LIABLE
23  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
25  * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
26  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
27  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
28  * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
29  * SUCH DAMAGE.
30  */
31 
32 #include <sys/cdefs.h>
33 __FBSDID("$FreeBSD$");
34 
#include <sys/types.h>
#include <sys/socket.h>

#include <errno.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

#include "pjdlog.h"
#include "proto_impl.h"
47 
/* Value stored in sp_magic while a context is live; cleared by sp_close(). */
#define	SP_CTX_MAGIC	0x50c3741

/* Possible values of sp_ctx.sp_side. */
#define	SP_SIDE_UNDEF		0	/* not chosen yet; both descriptors open */
#define	SP_SIDE_CLIENT		1	/* first op was a send; keeps sp_fd[0] */
#define	SP_SIDE_SERVER		2	/* first op was a recv; keeps sp_fd[1] */

/*
 * State for one socketpair:// connection, allocated by sp_client().
 */
struct sp_ctx {
	int			sp_magic;	/* SP_CTX_MAGIC while in use */
	int			sp_fd[2];	/* ends from socketpair(2); -1 once closed */
	int			sp_side;	/* one of SP_SIDE_* */
};

static void sp_close(void *ctx);
59 
60 static int
61 sp_client(const char *srcaddr, const char *dstaddr, void **ctxp)
62 {
63 	struct sp_ctx *spctx;
64 	int ret;
65 
66 	if (strcmp(dstaddr, "socketpair://") != 0)
67 		return (-1);
68 
69 	PJDLOG_ASSERT(srcaddr == NULL);
70 
71 	spctx = malloc(sizeof(*spctx));
72 	if (spctx == NULL)
73 		return (errno);
74 
75 	if (socketpair(PF_UNIX, SOCK_STREAM, 0, spctx->sp_fd) == -1) {
76 		ret = errno;
77 		free(spctx);
78 		return (ret);
79 	}
80 
81 	spctx->sp_side = SP_SIDE_UNDEF;
82 	spctx->sp_magic = SP_CTX_MAGIC;
83 	*ctxp = spctx;
84 
85 	return (0);
86 }
87 
88 static int
89 sp_send(void *ctx, const unsigned char *data, size_t size, int fd)
90 {
91 	struct sp_ctx *spctx = ctx;
92 	int sock;
93 
94 	PJDLOG_ASSERT(spctx != NULL);
95 	PJDLOG_ASSERT(spctx->sp_magic == SP_CTX_MAGIC);
96 
97 	switch (spctx->sp_side) {
98 	case SP_SIDE_UNDEF:
99 		/*
100 		 * If the first operation done by the caller is proto_send(),
101 		 * we assume this is the client.
102 		 */
103 		/* FALLTHROUGH */
104 		spctx->sp_side = SP_SIDE_CLIENT;
105 		/* Close other end. */
106 		close(spctx->sp_fd[1]);
107 		spctx->sp_fd[1] = -1;
108 	case SP_SIDE_CLIENT:
109 		PJDLOG_ASSERT(spctx->sp_fd[0] >= 0);
110 		sock = spctx->sp_fd[0];
111 		break;
112 	case SP_SIDE_SERVER:
113 		PJDLOG_ASSERT(spctx->sp_fd[1] >= 0);
114 		sock = spctx->sp_fd[1];
115 		break;
116 	default:
117 		PJDLOG_ABORT("Invalid socket side (%d).", spctx->sp_side);
118 	}
119 
120 	/* Someone is just trying to decide about side. */
121 	if (data == NULL)
122 		return (0);
123 
124 	return (proto_common_send(sock, data, size, fd));
125 }
126 
127 static int
128 sp_recv(void *ctx, unsigned char *data, size_t size, int *fdp)
129 {
130 	struct sp_ctx *spctx = ctx;
131 	int fd;
132 
133 	PJDLOG_ASSERT(spctx != NULL);
134 	PJDLOG_ASSERT(spctx->sp_magic == SP_CTX_MAGIC);
135 
136 	switch (spctx->sp_side) {
137 	case SP_SIDE_UNDEF:
138 		/*
139 		 * If the first operation done by the caller is proto_recv(),
140 		 * we assume this is the server.
141 		 */
142 		/* FALLTHROUGH */
143 		spctx->sp_side = SP_SIDE_SERVER;
144 		/* Close other end. */
145 		close(spctx->sp_fd[0]);
146 		spctx->sp_fd[0] = -1;
147 	case SP_SIDE_SERVER:
148 		PJDLOG_ASSERT(spctx->sp_fd[1] >= 0);
149 		fd = spctx->sp_fd[1];
150 		break;
151 	case SP_SIDE_CLIENT:
152 		PJDLOG_ASSERT(spctx->sp_fd[0] >= 0);
153 		fd = spctx->sp_fd[0];
154 		break;
155 	default:
156 		PJDLOG_ABORT("Invalid socket side (%d).", spctx->sp_side);
157 	}
158 
159 	/* Someone is just trying to decide about side. */
160 	if (data == NULL)
161 		return (0);
162 
163 	return (proto_common_recv(fd, data, size, fdp));
164 }
165 
166 static int
167 sp_descriptor(const void *ctx)
168 {
169 	const struct sp_ctx *spctx = ctx;
170 
171 	PJDLOG_ASSERT(spctx != NULL);
172 	PJDLOG_ASSERT(spctx->sp_magic == SP_CTX_MAGIC);
173 	PJDLOG_ASSERT(spctx->sp_side == SP_SIDE_CLIENT ||
174 	    spctx->sp_side == SP_SIDE_SERVER);
175 
176 	switch (spctx->sp_side) {
177 	case SP_SIDE_CLIENT:
178 		PJDLOG_ASSERT(spctx->sp_fd[0] >= 0);
179 		return (spctx->sp_fd[0]);
180 	case SP_SIDE_SERVER:
181 		PJDLOG_ASSERT(spctx->sp_fd[1] >= 0);
182 		return (spctx->sp_fd[1]);
183 	}
184 
185 	PJDLOG_ABORT("Invalid socket side (%d).", spctx->sp_side);
186 }
187 
188 static void
189 sp_close(void *ctx)
190 {
191 	struct sp_ctx *spctx = ctx;
192 
193 	PJDLOG_ASSERT(spctx != NULL);
194 	PJDLOG_ASSERT(spctx->sp_magic == SP_CTX_MAGIC);
195 
196 	switch (spctx->sp_side) {
197 	case SP_SIDE_UNDEF:
198 		PJDLOG_ASSERT(spctx->sp_fd[0] >= 0);
199 		close(spctx->sp_fd[0]);
200 		spctx->sp_fd[0] = -1;
201 		PJDLOG_ASSERT(spctx->sp_fd[1] >= 0);
202 		close(spctx->sp_fd[1]);
203 		spctx->sp_fd[1] = -1;
204 		break;
205 	case SP_SIDE_CLIENT:
206 		PJDLOG_ASSERT(spctx->sp_fd[0] >= 0);
207 		close(spctx->sp_fd[0]);
208 		spctx->sp_fd[0] = -1;
209 		PJDLOG_ASSERT(spctx->sp_fd[1] == -1);
210 		break;
211 	case SP_SIDE_SERVER:
212 		PJDLOG_ASSERT(spctx->sp_fd[1] >= 0);
213 		close(spctx->sp_fd[1]);
214 		spctx->sp_fd[1] = -1;
215 		PJDLOG_ASSERT(spctx->sp_fd[0] == -1);
216 		break;
217 	default:
218 		PJDLOG_ABORT("Invalid socket side (%d).", spctx->sp_side);
219 	}
220 
221 	spctx->sp_magic = 0;
222 	free(spctx);
223 }
224 
225 static struct proto sp_proto = {
226 	.prt_name = "socketpair",
227 	.prt_client = sp_client,
228 	.prt_send = sp_send,
229 	.prt_recv = sp_recv,
230 	.prt_descriptor = sp_descriptor,
231 	.prt_close = sp_close
232 };
233 
234 static __constructor void
235 sp_ctor(void)
236 {
237 
238 	proto_register(&sp_proto, false);
239 }
240