/*
 * This file and its contents are supplied under the terms of the
 * Common Development and Distribution License ("CDDL"), version 1.0.
 * You may only use this file in accordance with the terms of version
 * 1.0 of the CDDL.
 *
 * A full copy of the text of the CDDL should have accompanied this
 * source.  A copy of the CDDL is also available via the Internet at
 * http://www.illumos.org/license/CDDL.
 */

/*
 * Copyright 2020 OmniOS Community Edition (OmniOSce) Association.
 */

/*
 * Test ancillary data receipt via recvmsg()
 */

#include <stdio.h>
#include <errno.h>
#include <fcntl.h>
#include <signal.h>
#include <stdlib.h>
#include <string.h>
#include <strings.h>
#include <unistd.h>

#include <sys/types.h>
#include <sys/param.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <sys/wait.h>
#include <pthread.h>
#include <err.h>

static boolean_t debug;
static pthread_mutex_t mutex = PTHREAD_MUTEX_INITIALIZER;
static pthread_cond_t cv = PTHREAD_COND_INITIALIZER;
static pthread_mutex_t cmutex = PTHREAD_MUTEX_INITIALIZER;
static pthread_cond_t ccv = PTHREAD_COND_INITIALIZER;
static boolean_t server_ready = _B_FALSE;
static boolean_t client_done = _B_FALSE;

static in_addr_t testip;

#define	DEBUG(x) if (debug) printf x

#define	TESTPORT	32123

#define	RT_RECVTOS	0x1
#define	RT_RECVTTL	0x2
#define	RT_RECVPKTINFO	0x4
#define	RT_RECVMASK	0x7

#define	RT_SETTOS	0x10
#define	RT_SETTTL	0x20
#define	RT_STREAM	0x40
#define	RT_SKIP		0x80

typedef struct recvmsg_test {
	char *name;		/* Name of the test */
	uint8_t tos;		/* TOS to set */
	uint8_t ttl;		/* TTL to set */
	uint8_t flags;		/* Test flags, RT_ */
} recvmsg_test_t;

static recvmsg_test_t tests[] = {
	{
		.name = "baseline",
		.flags = 0,
	},

	/* Combinations of receive flags */
	{
		.name = "recv TOS",
		.flags = RT_RECVTOS,
	},

	{
		.name = "recv TTL",
		.flags = RT_RECVTTL,
	},

	{
		.name = "recv PKTINFO",
		.flags = RT_RECVPKTINFO,
	},

	{
		.name = "recv TOS,TTL",
		.flags = RT_RECVTOS | RT_RECVTTL,
	},

	{
		.name = "recv TTL,PKTINFO",
		.flags = RT_RECVTTL | RT_RECVPKTINFO,
	},

	{
		.name = "recv TOS,PKTINFO",
		.flags = RT_RECVTOS | RT_RECVPKTINFO,
	},

	{
		.name = "recv TOS,TTL,PKTINFO",
		.flags = RT_RECVTOS | RT_RECVTTL | RT_RECVPKTINFO,
	},

	/* Manually set TTL and TOS */

	{
		.name = "set TOS,TTL",
		.flags = RT_SETTOS | RT_SETTTL,
		.ttl = 11,
		.tos = 0xe0
	},

	{
		.name = "set/recv TOS,TTL",
		.flags = RT_SETTOS | RT_SETTTL | RT_RECVTOS | RT_RECVTTL,
		.ttl = 32,
		.tos = 0x48
	},

	{
		.name = "set TOS,TTL, recv PKTINFO",
		.flags = RT_SETTOS | RT_SETTTL | RT_RECVPKTINFO,
		.ttl = 173,
		.tos = 0x78
	},

	{
		.name = "set TOS,TTL, recv TOS,TTL,PKTINFO",
		.flags = RT_SETTOS | RT_SETTTL | RT_RECVTOS | RT_RECVTTL |
		    RT_RECVPKTINFO,
		.ttl = 54,
		.tos = 0x90
	},

	/* STREAM socket */

	{
		.name = "STREAM set TOS",
		.flags = RT_STREAM | RT_SETTOS,
		.tos = 0xe0
	},

	/*
	 * The ancillary data are not returned for the loopback TCP path,
	 * so these tests are skipped by default.
	 * To run them, use two different zones (or machines) and run:
	 *	recvmsg.64 -s 'test name'
	 * on the first, and:
	 *	recvmsg.64 -c <first machine IP> 'test name'
	 * on the second.
	 */
	{
		.name = "STREAM recv TOS",
		.flags = RT_STREAM | RT_RECVTOS | RT_SKIP,
	},

	{
		.name = "STREAM set/recv TOS",
		.flags = RT_STREAM | RT_SETTOS | RT_RECVTOS | RT_SKIP,
		.tos = 0x48
	},

	/* End of tests */

	{
		.name = NULL
	}
};

static boolean_t
servertest(recvmsg_test_t *t)
{
	struct sockaddr_in addr;
	boolean_t pass = _B_TRUE;
	int sockfd, readfd, acceptfd = -1, c = 1;

	DEBUG(("\nserver %s: starting\n", t->name));

	sockfd = socket(AF_INET,
	    t->flags & RT_STREAM ? SOCK_STREAM : SOCK_DGRAM, 0);
	if (sockfd == -1)
		err(EXIT_FAILURE, "failed to create server socket");

	addr.sin_family = AF_INET;
	addr.sin_addr.s_addr = INADDR_ANY;
	addr.sin_port = htons(TESTPORT);

	if (bind(sockfd, (struct sockaddr *)&addr, sizeof (addr)) == -1)
		err(EXIT_FAILURE, "server socket bind failed");

	if (t->flags & RT_RECVTOS) {
		DEBUG((" : setting RECVTOS\n"));
		if (setsockopt(sockfd, IPPROTO_IP, IP_RECVTOS, &c,
		    sizeof (c)) == -1) {
			printf("[FAIL] %s - "
			    "couldn't set TOS on server socket: %s\n",
			    t->name, strerror(errno));
			pass = _B_FALSE;
		}
	}

	if (t->flags & RT_RECVTTL) {
		DEBUG((" : setting RECVTTL\n"));
		if (setsockopt(sockfd, IPPROTO_IP, IP_RECVTTL, &c,
		    sizeof (c)) == -1) {
			printf("[FAIL] %s - "
			    "couldn't set TTL on server socket: %s\n",
			    t->name, strerror(errno));
			pass = _B_FALSE;
		}
	}

	if (t->flags & RT_RECVPKTINFO) {
		DEBUG((" : setting RECVPKTINFO\n"));
		if (setsockopt(sockfd, IPPROTO_IP, IP_PKTINFO, &c,
		    sizeof (c)) == -1) {
			printf("[FAIL] %s - "
			    "couldn't set PKTINFO on server socket: %s\n",
			    t->name, strerror(errno));
			pass = _B_FALSE;
		}
	}

	if (t->flags & RT_STREAM) {
		if (listen(sockfd, 1) == -1)
			err(EXIT_FAILURE, "Could not listen on sever socket");
	}

	/* Signal the client that the server is ready for the next test */
	if (debug)
		printf(" : signalling client\n");
	(void) pthread_mutex_lock(&mutex);
	server_ready = _B_TRUE;
	(void) pthread_cond_signal(&cv);
	(void) pthread_mutex_unlock(&mutex);

	if (t->flags & RT_STREAM) {
		struct sockaddr_in caddr;
		socklen_t sl = sizeof (caddr);

		if ((acceptfd = accept(sockfd, (struct sockaddr *)&caddr,
		    &sl)) == -1) {
			err(EXIT_FAILURE, "socket accept failed");
		}
		readfd = acceptfd;
	} else {
		readfd = sockfd;
	}

	/* Receive the datagram */

	struct msghdr msg;
	char buf[0x100];
	char cbuf[CMSG_SPACE(0x400)];
	struct iovec iov[1] = {0};
	ssize_t r;

	iov[0].iov_base = buf;
	iov[0].iov_len = sizeof (buf);

	bzero(&msg, sizeof (msg));
	msg.msg_iov = iov;
	msg.msg_iovlen = 1;
	msg.msg_control = cbuf;
	msg.msg_controllen = sizeof (cbuf);

	DEBUG((" : waiting for message\n"));

	r = recvmsg(readfd, &msg, 0);
	if (r <= 0) {
		printf("[FAIL] %s - recvmsg returned %d (%s)\n",
		    t->name, r, strerror(errno));
		pass = _B_FALSE;
		goto out;
	}

	DEBUG((" : recvmsg returned %d (flags=0x%x, controllen=%d)\n",
	    r, msg.msg_flags, msg.msg_controllen));

	if (r != strlen(t->name)) {
		printf("[FAIL] %s - got '%.*s' (%d bytes), expected '%s'\n",
		    t->name, r, buf, r, t->name);
		pass = _B_FALSE;
	}

	DEBUG((" : Received '%.*s'\n", r, buf));

	if (msg.msg_flags != 0) {
		printf("[FAIL] %s - received flags 0x%x\n",
		    t->name, msg.msg_flags);
		pass = _B_FALSE;
	}

	uint8_t flags = 0;

	for (struct cmsghdr *cm = CMSG_FIRSTHDR(&msg); cm != NULL;
	    cm = CMSG_NXTHDR(&msg, cm)) {
		uint8_t d;

		DEBUG((" : >> Got cmsg %x/%x - length %u\n",
		    cm->cmsg_level, cm->cmsg_type, cm->cmsg_len));

		if (cm->cmsg_level != IPPROTO_IP)
			continue;

		switch (cm->cmsg_type) {
		case IP_PKTINFO:
			flags |= RT_RECVPKTINFO;
			if (debug) {
				struct in_pktinfo *pi =
				    (struct in_pktinfo *)CMSG_DATA(cm);
				printf(" : ifIndex: %u\n", pi->ipi_ifindex);
			}
			break;
		case IP_RECVTTL:
			if (cm->cmsg_len != CMSG_LEN(sizeof (uint8_t))) {
				printf(
				    "[FAIL] %s - cmsg_len was %u expected %u\n",
				    t->name, cm->cmsg_len,
				    CMSG_LEN(sizeof (uint8_t)));
				pass = _B_FALSE;
				break;
			}
			flags |= RT_RECVTTL;
			memcpy(&d, CMSG_DATA(cm), sizeof (d));
			DEBUG((" : RECVTTL = %u\n", d));
			if (t->flags & RT_SETTTL && d != t->ttl) {
				printf("[FAIL] %s - TTL was %u, expected %u\n",
				    t->name, d, t->ttl);
				pass = _B_FALSE;
			}
			break;
		case IP_RECVTOS:
			if (cm->cmsg_len != CMSG_LEN(sizeof (uint8_t))) {
				printf(
				    "[FAIL] %s - cmsg_len was %u expected %u\n",
				    t->name, cm->cmsg_len,
				    CMSG_LEN(sizeof (uint8_t)));
				pass = _B_FALSE;
				break;
			}
			flags |= RT_RECVTOS;
			memcpy(&d, CMSG_DATA(cm), sizeof (d));
			DEBUG((" : RECVTOS = %u\n", d));
			if (t->flags & RT_SETTOS && d != t->tos) {
				printf("[FAIL] %s - TOS was %u, expected %u\n",
				    t->name, d, t->tos);
				pass = _B_FALSE;
			}
			break;
		}
	}

	if ((t->flags & RT_RECVMASK) != flags) {
		printf("[FAIL] %s - Did not receive everything expected, "
		    "flags %#x vs. %#x\n", t->name,
		    flags, t->flags & RT_RECVMASK);
		pass = _B_FALSE;
	}

	/* Wait for the client to finish */
	(void) pthread_mutex_lock(&cmutex);
	while (!client_done)
		(void) pthread_cond_wait(&ccv, &cmutex);
	client_done = _B_FALSE;
	(void) pthread_mutex_unlock(&cmutex);

out:
	if (acceptfd != -1)
		(void) close(acceptfd);
	(void) close(sockfd);

	if (pass)
		printf("[PASS] %s\n", t->name);

	return (pass);
}

static int
server(const char *test)
{
	int ret = EXIT_SUCCESS;
	recvmsg_test_t *t;

	for (t = tests; t->name != NULL; t++) {
		if (test != NULL) {
			if (strcmp(test, t->name) != 0)
				continue;
			client_done = _B_TRUE;
			return (servertest(t));
		}
		if (t->flags & RT_SKIP) {
			printf("[SKIP] %s - (requires two separate zones)\n",
			    t->name);
			continue;
		}
		if (!servertest(t))
			ret = EXIT_FAILURE;
	}

	return (ret);
}

static void
clienttest(recvmsg_test_t *t)
{
	struct sockaddr_in addr;
	int s, ret;

	DEBUG(("client %s: starting\n", t->name));

	s = socket(AF_INET, t->flags & RT_STREAM ? SOCK_STREAM : SOCK_DGRAM, 0);
	if (s == -1)
		err(EXIT_FAILURE, "failed to create client socket");

	addr.sin_family = AF_INET;
	addr.sin_addr.s_addr = testip;
	addr.sin_port = htons(TESTPORT);

	if (t->flags & RT_STREAM) {
		if (connect(s, (struct sockaddr *)&addr, sizeof (addr)) == -1)
			err(EXIT_FAILURE, "failed to connect to server");
	}

	if (t->flags & RT_SETTOS) {
		int c = t->tos;

		DEBUG(("client %s: setting TOS = 0x%x\n", t->name, c));
		if (setsockopt(s, IPPROTO_IP, IP_TOS, &c, sizeof (c)) == -1)
			err(EXIT_FAILURE, "could not set TOS on client socket");
	}

	if (t->flags & RT_SETTTL) {
		int c = t->ttl;

		DEBUG(("client %s: setting TTL = 0x%x\n", t->name, c));
		if (setsockopt(s, IPPROTO_IP, IP_TTL, &c, sizeof (c)) == -1)
			err(EXIT_FAILURE, "could not set TTL on client socket");
	}

	DEBUG(("client %s: sending\n", t->name));

	if (t->flags & RT_STREAM) {
		ret = send(s, t->name, strlen(t->name), 0);
		shutdown(s, SHUT_RDWR);
	} else {
		ret = sendto(s, t->name, strlen(t->name), 0,
		    (struct sockaddr *)&addr, sizeof (addr));
	}

	if (ret == -1)
		err(EXIT_FAILURE, "sendto failed to send data to server");

	DEBUG(("client %s: done\n", t->name));

	close(s);
}

static void *
client(void *arg)
{
	char *test = (char *)arg;
	recvmsg_test_t *t;

	for (t = tests; t->name != NULL; t++) {
		if (test != NULL) {
			if (strcmp(test, t->name) != 0)
				continue;
			clienttest(t);
			return (NULL);
		}
		if (t->flags & RT_SKIP)
			continue;
		/* Wait for the server to be ready to receive */
		(void) pthread_mutex_lock(&mutex);
		while (!server_ready)
			(void) pthread_cond_wait(&cv, &mutex);
		server_ready = _B_FALSE;
		(void) pthread_mutex_unlock(&mutex);
		clienttest(t);
		/* Tell the server we are done */
		(void) pthread_mutex_lock(&cmutex);
		client_done = _B_TRUE;
		(void) pthread_cond_signal(&ccv);
		(void) pthread_mutex_unlock(&cmutex);
	}

	return (NULL);
}

int
main(int argc, const char **argv)
{
	int ret = EXIT_SUCCESS;
	pthread_t cthread;

	if (argc > 1 && strcmp(argv[1], "-d") == 0) {
		debug = _B_TRUE;
		argc--, argv++;
	}

	/* -c <server IP> <test name> */
	if (argc == 4 && strcmp(argv[1], "-c") == 0) {
		testip = inet_addr(argv[2]);
		printf("TEST IP: %s\n", argv[2]);
		if (testip == INADDR_NONE) {
			err(EXIT_FAILURE,
			    "Could not parse destination IP address");
		}
		client((void *)argv[3]);
		return (ret);
	}

	/* -s <test name> */
	if (argc == 3 && strcmp(argv[1], "-s") == 0)
		return (server(argv[2]));

	testip = inet_addr("127.0.0.1");
	if (testip == INADDR_NONE)
		err(EXIT_FAILURE, "Could not parse destination IP address");

	if (pthread_create(&cthread, NULL, client, NULL) == -1)
		err(EXIT_FAILURE, "Could not create client thread");

	ret = server(NULL);

	if (pthread_join(cthread, NULL) != 0)
		err(EXIT_FAILURE, "join client thread failed");

	return (ret);
}