1 /*
2  * Copyright 2016 Jakub Klama <jceel@FreeBSD.org>
3  * All rights reserved
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted providing that the following conditions
7  * are met:
8  * 1. Redistributions of source code must retain the above copyright
9  *    notice, this list of conditions and the following disclaimer.
10  * 2. Redistributions in binary form must reproduce the above copyright
11  *    notice, this list of conditions and the following disclaimer in the
12  *    documentation and/or other materials provided with the distribution.
13  *
14  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
15  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
16  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
17  * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
18  * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
19  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
20  * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
21  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
22  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING
23  * IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
24  * POSSIBILITY OF SUCH DAMAGE.
25  *
26  */
27 
28 #include <stdlib.h>
29 #include <errno.h>
30 #include <string.h>
31 #include <unistd.h>
32 #include <pthread.h>
33 #include <assert.h>
34 #include <sys/types.h>
35 #ifdef __APPLE__
36 # include "../apple_endian.h"
37 #else
38 # include <sys/endian.h>
39 #endif
40 #include <sys/socket.h>
41 #include <sys/event.h>
42 #include <sys/uio.h>
43 #include <netdb.h>
44 #include "../lib9p.h"
45 #include "../lib9p_impl.h"
46 #include "../log.h"
47 #include "socket.h"
48 
/*
 * Per-connection state of the socket transport.  One instance is allocated
 * for every accepted client and handed to the transport callbacks as their
 * aux argument.
 */
struct l9p_socket_softc {
	/* 9P connection served over this socket. */
	struct l9p_connection *ls_conn;
	/* Peer address and its length (not referenced elsewhere in this file). */
	struct sockaddr ls_sockaddr;
	socklen_t ls_socklen;
	/* Thread running the receive loop for this connection. */
	pthread_t ls_thread;
	/* Connected socket descriptor. */
	int ls_fd;
};
57 
/*
 * Forward declarations: message framing, the three transport callbacks
 * installed on each connection, the per-connection thread, and the
 * restart-on-EINTR I/O helpers.
 */
static int l9p_socket_readmsg(struct l9p_socket_softc *sc, void **buf,
    size_t *size);
static int l9p_socket_get_response_buffer(struct l9p_request *req,
    struct iovec *iov, size_t *niovp, void *arg);
static int l9p_socket_send_response(struct l9p_request *req,
    const struct iovec *iov, const size_t niov, const size_t iolen,
    void *arg);
static void l9p_socket_drop_response(struct l9p_request *req,
    const struct iovec *iov, size_t niov, void *arg);
static void *l9p_socket_thread(void *arg);
static ssize_t xread(int fd, void *buf, size_t count);
static ssize_t xwrite(int fd, void *buf, size_t count);
68 
/*
 * Resolve host/port, bind and listen on the resulting stream-socket
 * addresses, and serve incoming connections forever.
 *
 * At most two listening sockets are opened (the size of the kevent arrays);
 * further addresses returned by getaddrinfo() are skipped with a warning.
 * Each accepted client is handed to l9p_socket_accept().
 *
 * On success this function does not return.  It returns -1 when name
 * resolution fails, when no address could be bound, or when the kqueue
 * cannot be created or polled; any sockets opened so far are closed before
 * returning.
 */
int
l9p_start_server(struct l9p_server *server, const char *host, const char *port)
{
	struct addrinfo *res, *res0, hints;
	struct kevent kev[2];
	struct kevent event[2];
	const int maxsockets = (int)(sizeof(kev) / sizeof(kev[0]));
	int err, kq, i, val, evs, nsockets = 0;
	int lasterr = 0;

	memset(&hints, 0, sizeof(hints));
	hints.ai_family = PF_UNSPEC;
	hints.ai_socktype = SOCK_STREAM;
	err = getaddrinfo(host, port, &hints, &res0);
	if (err != 0) {
		L9P_LOG(L9P_ERROR, "getaddrinfo(): %s", gai_strerror(err));
		return (-1);
	}

	/* Stop once kev[] is full so we never write past its end. */
	for (res = res0; res != NULL && nsockets < maxsockets;
	    res = res->ai_next) {
		int s = socket(res->ai_family, res->ai_socktype,
		    res->ai_protocol);

		if (s < 0) {
			lasterr = errno;
			continue;
		}

		/* Best effort: failing to set SO_REUSEADDR is not fatal. */
		val = 1;
		(void)setsockopt(s, SOL_SOCKET, SO_REUSEADDR, &val,
		    sizeof(val));

		if (bind(s, res->ai_addr, res->ai_addrlen) < 0 ||
		    listen(s, 10) < 0) {
			/* Save errno now; close() may overwrite it. */
			lasterr = errno;
			close(s);
			continue;
		}

		EV_SET(&kev[nsockets], s, EVFILT_READ, EV_ADD | EV_ENABLE, 0,
		    0, 0);
		nsockets++;
	}

	if (res != NULL)
		L9P_LOG(L9P_WARNING,
		    "listening on first %d addresses only", nsockets);

	freeaddrinfo(res0);

	if (nsockets < 1) {
		L9P_LOG(L9P_ERROR, "bind(): %s", strerror(lasterr));
		return (-1);
	}

	kq = kqueue();
	if (kq < 0) {
		L9P_LOG(L9P_ERROR, "kqueue(): %s", strerror(errno));
		goto fail;
	}

	if (kevent(kq, kev, nsockets, NULL, 0, NULL) < 0) {
		L9P_LOG(L9P_ERROR, "kevent(): %s", strerror(errno));
		goto fail;
	}

	for (;;) {
		evs = kevent(kq, NULL, 0, event, nsockets, NULL);
		if (evs < 0) {
			if (errno == EINTR)
				continue;

			L9P_LOG(L9P_ERROR, "kevent(): %s", strerror(errno));
			goto fail;
		}

		for (i = 0; i < evs; i++) {
			/*
			 * sockaddr_storage is large enough for any address
			 * family; a plain struct sockaddr would truncate
			 * IPv6 peer addresses.
			 */
			struct sockaddr_storage client_addr;
			socklen_t client_addr_len = sizeof(client_addr);
			int news = accept((int)event[i].ident,
			    (struct sockaddr *)&client_addr, &client_addr_len);

			if (news < 0) {
				L9P_LOG(L9P_WARNING, "accept(): %s",
				    strerror(errno));
				continue;
			}

			l9p_socket_accept(server, news,
			    (struct sockaddr *)&client_addr, client_addr_len);
		}
	}

fail:
	if (kq >= 0)
		close(kq);
	for (i = 0; i < nsockets; i++)
		close((int)kev[i].ident);
	return (-1);
}
145 
146 void
l9p_socket_accept(struct l9p_server * server,int conn_fd,struct sockaddr * client_addr,socklen_t client_addr_len)147 l9p_socket_accept(struct l9p_server *server, int conn_fd,
148     struct sockaddr *client_addr, socklen_t client_addr_len)
149 {
150 	struct l9p_socket_softc *sc;
151 	struct l9p_connection *conn;
152 	char host[NI_MAXHOST + 1];
153 	char serv[NI_MAXSERV + 1];
154 	int err;
155 
156 	err = getnameinfo(client_addr, client_addr_len, host, NI_MAXHOST, serv,
157 	    NI_MAXSERV, NI_NUMERICHOST | NI_NUMERICSERV);
158 
159 	if (err != 0) {
160 		L9P_LOG(L9P_WARNING, "cannot look up client name: %s",
161 		    gai_strerror(err));
162 	} else {
163 		L9P_LOG(L9P_INFO, "new connection from %s:%s", host, serv);
164 	}
165 
166 	if (l9p_connection_init(server, &conn) != 0) {
167 		L9P_LOG(L9P_ERROR, "cannot create new connection");
168 		return;
169 	}
170 
171 	sc = l9p_calloc(1, sizeof(*sc));
172 	sc->ls_conn = conn;
173 	sc->ls_fd = conn_fd;
174 
175 	/*
176 	 * Fill in transport handler functions and aux argument.
177 	 */
178 	conn->lc_lt.lt_aux = sc;
179 	conn->lc_lt.lt_get_response_buffer = l9p_socket_get_response_buffer;
180 	conn->lc_lt.lt_send_response = l9p_socket_send_response;
181 	conn->lc_lt.lt_drop_response = l9p_socket_drop_response;
182 
183 	err = pthread_create(&sc->ls_thread, NULL, l9p_socket_thread, sc);
184 	if (err) {
185 		L9P_LOG(L9P_ERROR,
186 		    "pthread_create (for connection from %s:%s): error %s",
187 		    host, serv, strerror(err));
188 		l9p_connection_close(sc->ls_conn);
189 		free(sc);
190 	}
191 }
192 
193 static void *
l9p_socket_thread(void * arg)194 l9p_socket_thread(void *arg)
195 {
196 	struct l9p_socket_softc *sc = (struct l9p_socket_softc *)arg;
197 	struct iovec iov;
198 	void *buf;
199 	size_t length;
200 
201 	for (;;) {
202 		if (l9p_socket_readmsg(sc, &buf, &length) != 0)
203 			break;
204 
205 		iov.iov_base = buf;
206 		iov.iov_len = length;
207 		l9p_connection_recv(sc->ls_conn, &iov, 1, NULL);
208 		free(buf);
209 	}
210 
211 	L9P_LOG(L9P_INFO, "connection closed");
212 	l9p_connection_close(sc->ls_conn);
213 	free(sc);
214 	return (NULL);
215 }
216 
217 static int
l9p_socket_readmsg(struct l9p_socket_softc * sc,void ** buf,size_t * size)218 l9p_socket_readmsg(struct l9p_socket_softc *sc, void **buf, size_t *size)
219 {
220 	uint32_t msize;
221 	size_t toread;
222 	ssize_t ret;
223 	void *buffer;
224 	int fd = sc->ls_fd;
225 
226 	assert(fd > 0);
227 
228 	buffer = l9p_malloc(sizeof(uint32_t));
229 
230 	ret = xread(fd, buffer, sizeof(uint32_t));
231 	if (ret < 0) {
232 		L9P_LOG(L9P_ERROR, "read(): %s", strerror(errno));
233 		return (-1);
234 	}
235 
236 	if (ret != sizeof(uint32_t)) {
237 		if (ret == 0)
238 			L9P_LOG(L9P_DEBUG, "%p: EOF", (void *)sc->ls_conn);
239 		else
240 			L9P_LOG(L9P_ERROR,
241 			    "short read: %zd bytes of %zd expected",
242 			    ret, sizeof(uint32_t));
243 		return (-1);
244 	}
245 
246 	msize = le32toh(*(uint32_t *)buffer);
247 	toread = msize - sizeof(uint32_t);
248 	buffer = l9p_realloc(buffer, msize);
249 
250 	ret = xread(fd, (char *)buffer + sizeof(uint32_t), toread);
251 	if (ret < 0) {
252 		L9P_LOG(L9P_ERROR, "read(): %s", strerror(errno));
253 		return (-1);
254 	}
255 
256 	if (ret != (ssize_t)toread) {
257 		L9P_LOG(L9P_ERROR, "short read: %zd bytes of %zd expected",
258 		    ret, toread);
259 		return (-1);
260 	}
261 
262 	*size = msize;
263 	*buf = buffer;
264 	L9P_LOG(L9P_INFO, "%p: read complete message, buf=%p size=%d",
265 	    (void *)sc->ls_conn, buffer, msize);
266 
267 	return (0);
268 }
269 
270 static int
l9p_socket_get_response_buffer(struct l9p_request * req,struct iovec * iov,size_t * niovp,void * arg __unused)271 l9p_socket_get_response_buffer(struct l9p_request *req, struct iovec *iov,
272     size_t *niovp, void *arg __unused)
273 {
274 	size_t size = req->lr_conn->lc_msize;
275 	void *buf;
276 
277 	buf = l9p_malloc(size);
278 	iov[0].iov_base = buf;
279 	iov[0].iov_len = size;
280 
281 	*niovp = 1;
282 	return (0);
283 }
284 
285 static int
l9p_socket_send_response(struct l9p_request * req __unused,const struct iovec * iov,const size_t niov __unused,const size_t iolen,void * arg)286 l9p_socket_send_response(struct l9p_request *req __unused,
287     const struct iovec *iov, const size_t niov __unused, const size_t iolen,
288     void *arg)
289 {
290 	struct l9p_socket_softc *sc = (struct l9p_socket_softc *)arg;
291 
292 	assert(sc->ls_fd >= 0);
293 
294 	L9P_LOG(L9P_DEBUG, "%p: sending reply, buf=%p, size=%d", arg,
295 	    iov[0].iov_base, iolen);
296 
297 	if (xwrite(sc->ls_fd, iov[0].iov_base, iolen) != (int)iolen) {
298 		L9P_LOG(L9P_ERROR, "short write: %s", strerror(errno));
299 		return (-1);
300 	}
301 
302 	free(iov[0].iov_base);
303 	return (0);
304 }
305 
306 static void
l9p_socket_drop_response(struct l9p_request * req __unused,const struct iovec * iov,size_t niov __unused,void * arg __unused)307 l9p_socket_drop_response(struct l9p_request *req __unused,
308     const struct iovec *iov, size_t niov __unused, void *arg __unused)
309 {
310 
311 	L9P_LOG(L9P_DEBUG, "%p: drop buf=%p", arg, iov[0].iov_base);
312 	free(iov[0].iov_base);
313 }
314 
/*
 * read(2) wrapper that keeps reading until count bytes have arrived.
 *
 * Calls interrupted by a signal (EINTR) are retried.  Returns the number of
 * bytes stored in buf, which is less than count only when end-of-file was
 * reached first, or -1 if read() fails with any other error.
 */
static ssize_t
xread(int fd, void *buf, size_t count)
{
	char *cursor = buf;
	size_t remaining = count;
	ssize_t n;

	while (remaining > 0) {
		n = read(fd, cursor, remaining);
		if (n > 0) {
			cursor += n;
			remaining -= (size_t)n;
		} else if (n == 0) {
			/* End of file: report what we have so far. */
			break;
		} else if (errno != EINTR) {
			return (-1);
		}
	}

	return ((ssize_t)(count - remaining));
}
338 
/*
 * write(2) wrapper that keeps writing until count bytes have been sent.
 *
 * Calls interrupted by a signal (EINTR) are retried.  Returns the number of
 * bytes written, which is less than count only if write() returned 0, or -1
 * if write() fails with any other error.
 */
static ssize_t
xwrite(int fd, void *buf, size_t count)
{
	const char *base = buf;
	size_t sent;
	ssize_t n;

	for (sent = 0; sent < count; sent += (size_t)n) {
		n = write(fd, base + sent, count - sent);
		if (n == 0)
			break;
		if (n < 0) {
			if (errno != EINTR)
				return (-1);
			/* Interrupted: retry without advancing. */
			n = 0;
		}
	}

	return ((ssize_t)sent);
}
362