/*
 * CDDL HEADER START
 *
 * The contents of this file are subject to the terms of the
 * Common Development and Distribution License (the "License").
 * You may not use this file except in compliance with the License.
 *
 * You can obtain a copy of the license at usr/src/OPENSOLARIS.LICENSE
 * or http://www.opensolaris.org/os/licensing.
 * See the License for the specific language governing permissions
 * and limitations under the License.
 *
 * When distributing Covered Code, include this CDDL HEADER in each
 * file and include the License file at usr/src/OPENSOLARIS.LICENSE.
 * If applicable, add the following below this CDDL HEADER, with the
 * fields enclosed by brackets "[]" replaced with your own identifying
 * information: Portions Copyright [yyyy] [name of copyright owner]
 *
 * CDDL HEADER END
 */
/*
 * Copyright 2006 Sun Microsystems, Inc.  All rights reserved.
 * Use is subject to license terms.
 */

/*	Copyright (c) 1983, 1984, 1985, 1986, 1987, 1988, 1989 AT&T	*/
/*	  All Rights Reserved  	*/

/*
 * University Copyright- Copyright (c) 1982, 1986, 1988
 * The Regents of the University of California
 * All Rights Reserved
 *
 * University Acknowledgment- Portions of this document are derived from
 * software developed by the University of California, Berkeley, and its
 * contributors.
 */

#pragma ident	"%Z%%M%	%I%	%E% SMI"

/*
 * TFTP User Program -- Protocol Machines
 */
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/time.h>
#include <sys/stat.h>

#include <signal.h>
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <errno.h>
#include <string.h>
#include <stddef.h>
#include <inttypes.h>

#include "tftpcommon.h"
#include "tftpprivate.h"

/*
 * Forward declarations for the file-local helpers defined below.
 *
 * The *_str() functions produce the option value strings placed in a
 * request packet and the *_handler() functions validate the values echoed
 * back in an OACK; both sets are referenced from the options[] table.
 */
static char	*blksize_str(void);
static char	*timeout_str(void);
static char	*tsize_str(void);
static int	blksize_handler(char *);
static int	timeout_handler(char *);
static int	tsize_handler(char *);
static int	add_options(char *, char *);
static int	process_oack(tftpbuf *, int);
static void	nak(int);
static void	startclock(void);
static void	stopclock(void);
static void	printstats(char *, off_t);
static int	makerequest(int, char *, struct tftphdr *, char *);
static void	tpacket(char *, struct tftphdr *, int);

/*
 * Table of the TFTP options this client knows about, terminated by an
 * entry whose opt_name is NULL.
 *
 * opt_name	option name as it appears in request and OACK packets
 * opt_str	returns the value string to send in a request, or NULL if
 *		the option should be left out of the request
 * opt_handler	validates the value received in an OACK; returns -1 on
 *		success or an error code on failure
 */
static struct options {
	char	*opt_name;
	char	*(*opt_str)(void);
	int	(*opt_handler)(char *);
} options[] = {
	{ "blksize",	blksize_str, blksize_handler },
	{ "timeout",	timeout_str, timeout_handler },
	{ "tsize",	tsize_str, tsize_handler },
	{ NULL }
};

/*
 * Scratch buffer shared by the *_str() functions; each call overwrites the
 * previous contents, so the returned string must be consumed before the
 * next *_str() call.
 */
static char		optbuf[MAX_OPTVAL_LEN];
/* True when a tsize option is to be included in the current request */
static boolean_t	tsize_set;

/* Buffer for received ACK/OACK packets and for the ACK/ERROR packets we send */
static tftpbuf	ackbuf;
/* Retransmission time accumulated by timer() for the current packet */
static int	timeout;
/* Transfer size sent in a write request or learned from an OACK */
static off_t	tsize;
/* setjmp() context that timer() jumps back to in order to retransmit */
static jmp_buf	timeoutbuf;

int	blocksize = SEGSIZE;	/* Number of data bytes in a DATA packet */

/*
 * SIGALRM handler used to drive retransmission.
 *
 * Each expiry adds the retransmission interval to the accumulated timeout.
 * While that total is below maxtimeout the handler re-installs itself and
 * jumps back to the last setjmp(timeoutbuf) so the current packet is sent
 * again; otherwise the transfer is abandoned by jumping to toplevel.
 * This function never returns normally.
 */
/*ARGSUSED*/
static void
timer(int signum)
{
	timeout += rexmtval;

	if (timeout < maxtimeout) {
		/* Still within budget: re-arm and retransmit. */
		(void) signal(SIGALRM, timer);
		longjmp(timeoutbuf, 1);
	}

	/* Retransmission budget exhausted: give up on the transfer. */
	(void) fputs("Transfer timed out.\n", stderr);
	longjmp(toplevel, -1);
}

/*
 * Send the requested file.
 *
 *	fd	open descriptor for the local file to send; it is wrapped in
 *		a stdio stream here and closed before returning
 *	name	file name placed in the write request
 *	mode	transfer mode string; "netascii" selects nl->crlf conversion
 *
 * The first pass through the outer loop (count == 0) sends the write
 * request; every later pass sends one DATA block.  After each send the
 * inner loop waits for the matching ACK (or, for the request, an OACK).
 * When the alarm fires, timer() longjmps back to the setjmp() below so the
 * same packet is sent again.
 */
void
tftp_sendfile(int fd, char *name, char *mode)
{
	struct tftphdr *ap;	/* data and ack packets */
	struct tftphdr *dp;
	int count = 0, size, n;
	ushort_t block = 0;
	off_t amount = 0;
	struct sockaddr_in6 from;
	socklen_t fromlen;
	int convert;	/* true if doing nl->crlf conversion */
	FILE *file;
	struct stat statb;
	int errcode;

	startclock();	/* start stat's clock */
	dp = r_init();	/* reset fillbuf/read-ahead code */
	ap = &ackbuf.tb_hdr;
	file = fdopen(fd, "r");
	if (file == NULL) {
		/* No stream was created, so the descriptor is still ours. */
		perror("tftp: fdopen");
		(void) close(fd);
		return;
	}
	convert = (strcmp(mode, "netascii") == 0);

	/*
	 * Only offer tsize when the option is enabled, no conversion will
	 * alter the byte count, and the file size could be determined.
	 */
	tsize_set = ((tsize_opt != 0) && !convert && (fstat(fd, &statb) == 0));
	if (tsize_set)
		tsize = statb.st_size;

	do {
		(void) signal(SIGALRM, timer);
		if (count == 0) {
			if ((size = makerequest(WRQ, name, dp, mode)) == -1) {
				(void) fprintf(stderr,
				    "tftp: Error: Write request packet too "
				    "big\n");
				(void) fclose(file);
				return;
			}
			/*
			 * makerequest() returns the whole packet length;
			 * the send below adds the 4 header bytes back.
			 */
			size -= 4;
		} else {
			size = readit(file, &dp, convert);
			if (size < 0) {
				nak(errno + 100);
				break;
			}
			dp->th_opcode = htons((ushort_t)DATA);
			dp->th_block = htons((ushort_t)block);
		}
		timeout = 0;
		/* timer() jumps back here to retransmit the packet in dp */
		(void) setjmp(timeoutbuf);
		if (trace)
			tpacket("sent", dp, size + 4);
		n = sendto(f, dp, size + 4, 0,
		    (struct sockaddr *)&sin6, sizeof (sin6));
		if (n != size + 4) {
			perror("tftp: sendto");
			goto abort;
		}
		/* Can't read-ahead first block as OACK may change blocksize */
		if (count != 0)
			read_ahead(file, convert);
		(void) alarm(rexmtval);
		for (; ; ) {
			/* Allow the alarm to interrupt only the receive */
			(void) sigrelse(SIGALRM);
			do {
				fromlen = (socklen_t)sizeof (from);
				n = recvfrom(f, ackbuf.tb_data,
				    sizeof (ackbuf.tb_data), 0,
				    (struct sockaddr *)&from, &fromlen);
				if (n < 0) {
					perror("tftp: recvfrom");
					goto abort;
				}
				/* Ignore packets too short to hold a header */
			} while (n < offsetof(struct tftphdr, th_data));
			(void) sighold(SIGALRM);
			/* Reply to whichever port the response came from */
			sin6.sin6_port = from.sin6_port;   /* added */
			if (trace)
				tpacket("received", ap, n);
			/* should verify packet came from server */
			ap->th_opcode = ntohs(ap->th_opcode);
			if (ap->th_opcode == ERROR) {
				ap->th_code = ntohs(ap->th_code);
				(void) fprintf(stderr,
					"Error code %d", ap->th_code);
				/* The "*" precision must be passed as an int */
				if (n > offsetof(struct tftphdr, th_data))
					(void) fprintf(stderr, ": %.*s",
					    (int)(n - offsetof(struct tftphdr,
					    th_data)), ap->th_msg);
				(void) fputc('\n', stderr);
				goto abort;
			}
			if ((count == 0) && (ap->th_opcode == OACK)) {
				errcode = process_oack(&ackbuf, n);
				if (errcode >= 0) {
					nak(errcode);
					(void) fputs("Rejected OACK\n",
					    stderr);
					goto abort;
				}
				break;
			}
			if (ap->th_opcode == ACK) {
				ap->th_block = ntohs(ap->th_block);
				if (ap->th_block == block) {
					break;
				}
				/*
				 * Never resend the current DATA packet on
				 * receipt of a duplicate ACK, doing so would
				 * cause the "Sorcerer's Apprentice Syndrome".
				 */
			}
		}
		cancel_alarm();
		/* The request packet itself does not count as file data */
		if (count > 0)
			amount += size;
		block++;
		count++;
		/*
		 * A short DATA block ends the transfer; count == 1 keeps the
		 * loop going after the request regardless of its length.
		 */
	} while (size == blocksize || count == 1);
abort:
	cancel_alarm();
	(void) fclose(file);
	stopclock();
	if (amount > 0)
		printstats("Sent", amount);
}

/*
 * Receive a file.
 *
 *	fd	open descriptor for the local file to write; it is wrapped
 *		in a stdio stream here and closed before returning
 *	name	file name placed in the read request
 *	mode	transfer mode string; "netascii" selects crlf->lf conversion
 *
 * The first pass through the outer loop sends the read request; every
 * later pass acknowledges the block just written.  After each send the
 * inner loop waits for the next DATA block (or an OACK in reply to the
 * request).  When the alarm fires, timer() longjmps back to the setjmp()
 * below so the same packet is sent again.
 */
void
tftp_recvfile(int fd, char *name, char *mode)
{
	struct tftphdr *ap;
	struct tftphdr *dp;
	ushort_t block = 1;
	int n, size;
	off_t amount = 0;	/* off_t, matching printstats() */
	struct sockaddr_in6 from;
	socklen_t fromlen;
	boolean_t firsttrip = B_TRUE;
	FILE *file;
	int convert;	/* true if converting crlf -> lf */
	int errcode;

	startclock();
	dp = w_init();
	ap = &ackbuf.tb_hdr;
	file = fdopen(fd, "w");
	if (file == NULL) {
		/* No stream was created, so the descriptor is still ours. */
		perror("tftp: fdopen");
		(void) close(fd);
		return;
	}
	convert = (strcmp(mode, "netascii") == 0);

	/* Ask for the size with a value of 0; the OACK supplies it */
	tsize_set = (tsize_opt != 0);
	if (tsize_set)
		tsize = 0;

	if ((size = makerequest(RRQ, name, ap, mode)) == -1) {
		(void) fprintf(stderr,
		    "tftp: Error: Read request packet too big\n");
		(void) fclose(file);
		return;
	}

	do {
		(void) signal(SIGALRM, timer);
		if (firsttrip) {
			/* The request built above is sent as is */
			firsttrip = B_FALSE;
		} else {
			ap->th_opcode = htons((ushort_t)ACK);
			ap->th_block = htons((ushort_t)(block));
			size = 4;
			block++;
		}

send_oack_ack:
		timeout = 0;
		/* timer() jumps back here to retransmit the packet in ackbuf */
		(void) setjmp(timeoutbuf);
send_ack:
		if (trace)
			tpacket("sent", ap, size);
		if (sendto(f, ackbuf.tb_data, size, 0, (struct sockaddr *)&sin6,
		    sizeof (sin6)) != size) {
			(void) alarm(0);
			perror("tftp: sendto");
			goto abort;
		}
		if (write_behind(file, convert) < 0) {
			nak(errno + 100);
			goto abort;
		}
		(void) alarm(rexmtval);
		for (; ; ) {
			/* Allow the alarm to interrupt only the receive */
			(void) sigrelse(SIGALRM);
			do  {
				fromlen = (socklen_t)sizeof (from);
				n = recvfrom(f, dp, blocksize + 4, 0,
				    (struct sockaddr *)&from, &fromlen);
				if (n < 0) {
					perror("tftp: recvfrom");
					goto abort;
				}
				/* Ignore packets too short to hold a header */
			} while (n < offsetof(struct tftphdr, th_data));
			(void) sighold(SIGALRM);
			/* Reply to whichever port the response came from */
			sin6.sin6_port = from.sin6_port;   /* added */
			if (trace)
				tpacket("received", dp, n);
			/* should verify that the packet came from the server */
			dp->th_opcode = ntohs(dp->th_opcode);
			if (dp->th_opcode == ERROR) {
				dp->th_code = ntohs(dp->th_code);
				(void) fprintf(stderr, "Error code %d",
				    dp->th_code);
				/* The "*" precision must be passed as an int */
				if (n > offsetof(struct tftphdr, th_data))
					(void) fprintf(stderr, ": %.*s",
					    (int)(n - offsetof(struct tftphdr,
					    th_data)), dp->th_msg);
				(void) fputc('\n', stderr);
				goto abort;
			}
			if ((block == 1) && (dp->th_opcode == OACK)) {
				errcode = process_oack((tftpbuf *)dp, n);
				if (errcode >= 0) {
					cancel_alarm();
					nak(errcode);
					(void) fputs("Rejected OACK\n",
					    stderr);
					(void) fclose(file);
					return;
				}
				/* An accepted OACK is acknowledged as block 0 */
				ap->th_opcode = htons((ushort_t)ACK);
				ap->th_block = htons(0);
				size = 4;
				goto send_oack_ack;
			}
			if (dp->th_opcode == DATA) {
				int j;

				dp->th_block = ntohs(dp->th_block);
				if (dp->th_block == block) {
					break;	/* have next packet */
				}
				/*
				 * On an error, try to synchronize
				 * both sides.
				 */
				j = synchnet(f);
				if (j < 0) {
					perror("tftp: recvfrom");
					goto abort;
				}
				if ((j > 0) && trace) {
					(void) printf("discarded %d packets\n",
					    j);
				}
				if (dp->th_block == (block-1)) {
					goto send_ack;  /* resend ack */
				}
			}
		}
		cancel_alarm();
		size = writeit(file, &dp, n - 4, convert);
		if (size < 0) {
			nak(errno + 100);
			goto abort;
		}
		amount += size;
		/* A short DATA block marks the end of the transfer */
	} while (size == blocksize);

	cancel_alarm();
	if (write_behind(file, convert) < 0) {	/* flush last buffer */
		nak(errno + 100);
		goto abort;
	}
	n = fclose(file);
	file = NULL;	/* already closed; keep abort from closing it again */
	if (n == EOF) {
		nak(errno + 100);
		goto abort;
	}

	/* Everything is safely written, so acknowledge the last block */
	ap->th_opcode = htons((ushort_t)ACK);
	ap->th_block = htons((ushort_t)block);
	if (trace)
		tpacket("sent", ap, 4);
	if (sendto(f, ackbuf.tb_data, 4, 0,
	    (struct sockaddr *)&sin6, sizeof (sin6)) != 4)
		perror("tftp: sendto");

abort:
	cancel_alarm();
	if (file != NULL)
		(void) fclose(file);
	stopclock();
	if (amount > 0)
		printstats("Received", amount);
}

/*
 * Build a read or write request packet in the buffer at tp.
 *
 *	request	opcode to store in the packet (RRQ or WRQ)
 *	name	file name, copied in as the first NUL-terminated field
 *	tp	buffer in which the packet is assembled
 *	mode	transfer mode, copied in as the second NUL-terminated field
 *
 * Any options selected by the user are appended after the mode.  Returns
 * the total length of the packet in bytes, or -1 if it would not fit.
 */
static int
makerequest(int request, char *name, struct tftphdr *tp, char *mode)
{
	char *fields[2];
	char *next, *limit;
	int i, used;

	tp->th_opcode = htons((ushort_t)request);
	next = (char *)&tp->th_stuff;

	/* Maximum size of a request packet is 512 bytes (RFC 2347) */
	limit = (char *)tp + SEGSIZE;

	/* The file name and the mode are laid down back to back. */
	fields[0] = name;
	fields[1] = mode;
	for (i = 0; i < 2; i++) {
		/* +1 accounts for the terminating NUL of the field */
		used = strlcpy(next, fields[i], limit - next) + 1;
		next += used;
		if (next > limit)
			return (-1);
	}

	used = add_options(next, limit);
	if (used == -1)
		return (-1);
	next += used;

	return (next - (char *)tp);
}

/*
 * Produce the blksize option value for a request packet.
 *
 * As a side effect the block size in use is reset to SEGSIZE, so every
 * request starts from the default until an OACK says otherwise.  Returns
 * NULL when no block size was requested (blksize is 0), otherwise a pointer
 * to the shared optbuf holding the value in decimal.
 */
static char *
blksize_str(void)
{
	blocksize = SEGSIZE;

	if (blksize != 0) {
		(void) snprintf(optbuf, sizeof (optbuf), "%d", blksize);
		return (optbuf);
	}
	return (NULL);
}

/*
 * Produce the timeout option value for a request packet.
 *
 * Returns NULL when no timeout was requested (srexmtval is 0), otherwise a
 * pointer to the shared optbuf holding the value in decimal.
 */
static char *
timeout_str(void)
{
	if (srexmtval != 0) {
		(void) snprintf(optbuf, sizeof (optbuf), "%d", srexmtval);
		return (optbuf);
	}
	return (NULL);
}

/*
 * Produce the tsize option value for a request packet.
 *
 * Returns NULL when tsize is not being sent with this request, otherwise a
 * pointer to the shared optbuf holding the current tsize value.
 */
static char *
tsize_str(void)
{
	if (tsize_set) {
		(void) snprintf(optbuf, sizeof (optbuf), OFF_T_FMT, tsize);
		return (optbuf);
	}
	return (NULL);
}

/*
 * Validate and action the blksize option value string from the OACK packet.
 *
 * The value is accepted only if blksize was requested, the string is a
 * complete decimal number, and it lies between MIN_BLKSIZE and the size we
 * asked for.  On acceptance it becomes the block size in use.
 *
 * Returns -1 on success or an error code on failure.
 */
static int
blksize_handler(char *optstr)
{
	char *endp;
	long value;

	/* Make sure the option was requested */
	if (blksize == 0)
		return (EOPTNEG);
	errno = 0;
	/*
	 * Range check the full strtol() result; narrowing to int first
	 * could wrap an oversized value back into the valid range.
	 */
	value = strtol(optstr, &endp, 10);
	if (errno != 0 || endp == optstr || *endp != '\0' ||
	    value < MIN_BLKSIZE || value > blksize)
		return (EOPTNEG);
	blocksize = (int)value;
	return (-1);
}

/*
 * Validate and action the timeout option value string from the OACK packet.
 *
 * The value is accepted only if a timeout was requested and the string is
 * a complete decimal number equal to the value we asked for.
 *
 * Returns -1 on success or an error code on failure.
 */
static int
timeout_handler(char *optstr)
{
	char *endp;
	long value;

	/* Make sure the option was requested */
	if (srexmtval == 0)
		return (EOPTNEG);
	errno = 0;
	/*
	 * Compare the full strtol() result; narrowing to int first could
	 * make an oversized value appear equal to the one requested.
	 */
	value = strtol(optstr, &endp, 10);
	if (errno != 0 || endp == optstr || *endp != '\0' ||
	    value != srexmtval)
		return (EOPTNEG);
	/*
	 * Nothing to set, client and server retransmission intervals are
	 * set separately in the client.
	 */
	return (-1);
}

/*
 * Validate and action the tsize option value string from the OACK packet.
 *
 * The value is accepted only if tsize was requested and the string is a
 * complete, non-negative decimal number.  A value with no digits at all is
 * rejected rather than being treated as zero.
 *
 * Returns -1 on success or an error code on failure.
 */
static int
tsize_handler(char *optstr)
{
	char *endp;
	longlong_t value;

	/* Make sure the option was requested */
	if (tsize_set == B_FALSE)
		return (EOPTNEG);
	errno = 0;
	value = strtoll(optstr, &endp, 10);
	if (errno != 0 || endp == optstr || *endp != '\0' || value < 0)
		return (EOPTNEG);
#if _FILE_OFFSET_BITS == 32
	/* The file could not be represented with a 32-bit off_t */
	if (value > MAXOFF_T)
		return (ENOSPACE);
#endif
	/*
	 * Don't bother checking the tsize value we specified in a write
	 * request is echoed back in the OACK.
	 */
	if (tsize == 0)
		tsize = value;
	return (-1);
}

/*
 * Add TFTP options to a request packet.
 */
static int
add_options(char *obuf, char *obufend)
{
	int i;
	char *cp, *ostr;

	cp = obuf;
	for (i = 0; options[i].opt_name != NULL; i++) {
		ostr = options[i].opt_str();
		if (ostr != NULL) {
			cp += strlcpy(cp, options[i].opt_name, obufend - cp)
			    + 1;
			if (cp > obufend)
				return (-1);

			cp += strlcpy(cp, ostr, obufend - cp) + 1;
			if (cp > obufend)
				return (-1);
		}
	}
	return (cp - obuf);
}

/*
 * Process the OACK packet the server sent in reply to the options in our
 * request packet.
 *
 *	oackbuf	buffer holding the received OACK
 *	n	number of bytes received into oackbuf
 *
 * The packet is walked as a sequence of name/value string pairs.  Each
 * name is matched without regard to case against options[] and the value
 * is handed to that option's handler.  Returns -1 if every pair was
 * accepted, otherwise the error code for the first problem found.
 */
static int
process_oack(tftpbuf *oackbuf, int n)
{
	struct tftphdr *hdr = &oackbuf->tb_hdr;
	struct options *op;
	char *next, *limit, *name, *value;
	int rc;

	next = (char *)&hdr->th_stuff;
	limit = (char *)oackbuf + n;

	while (next < limit) {
		/* Split off one name/value pair */
		name = next;
		value = next_field(name, limit);
		if (value == NULL)
			return (EOPTNEG);
		next = next_field(value, limit);
		if (next == NULL)
			return (EOPTNEG);

		/* Find the table entry for this option name */
		for (op = options; op->opt_name != NULL; op++) {
			if (strcasecmp(name, op->opt_name) == 0)
				break;
		}
		if (op->opt_name == NULL)
			return (EOPTNEG);

		rc = op->opt_handler(value);
		if (rc >= 0)
			return (rc);
	}
	return (-1);
}

/*
 * Send a nak packet (error message).
 * Error code passed in is one of the
 * standard TFTP codes, or a UNIX errno
 * offset by 100.
 */
static void
nak(int error)
{
	struct tftphdr *tp;
	int length;
	struct errmsg *pe;

	tp = &ackbuf.tb_hdr;
	tp->th_opcode = htons((ushort_t)ERROR);
	tp->th_code = htons((ushort_t)error);
	for (pe = errmsgs; pe->e_code >= 0; pe++)
		if (pe->e_code == error)
			break;
	if (pe->e_code < 0) {
		pe->e_msg = strerror(error - 100);
		tp->th_code = EUNDEF;
	}
	(void) strlcpy(tp->th_msg, pe->e_msg,
	    sizeof (ackbuf) - sizeof (struct tftphdr));
	length = strlen(pe->e_msg) + 4;
	if (trace)
		tpacket("sent", tp, length);
	if (sendto(f, ackbuf.tb_data, length, 0,
	    (struct sockaddr *)&sin6, sizeof (sin6)) != length)
		perror("nak");
}

/*
 * Print a one line description of a TFTP packet on stdout for tracing.
 *
 *	s	text printed first on the line, e.g. "sent" or "received"
 *	tp	the packet, with its header fields in network byte order
 *	n	length of the packet in bytes
 *
 * An opcode outside the RRQ..OACK range is printed in hex and nothing
 * further is decoded.
 */
static void
tpacket(char *s, struct tftphdr *tp, int n)
{
	static const char *const opcodes[] =
		{ "#0", "RRQ", "WRQ", "DATA", "ACK", "ERROR", "OACK" };
	char *cp, *file, *mode;
	ushort_t op = ntohs(tp->th_opcode);
	char *tpend;

	if (op < RRQ || op > OACK)
		(void) printf("%s opcode=%x ", s, op);
	else
		(void) printf("%s %s ", s, opcodes[op]);

	switch (op) {
	case RRQ:
	case WRQ:
		tpend = (char *)tp + n;
		n -= sizeof (tp->th_opcode);
		file = (char *)&tp->th_stuff;
		/* If a field is unterminated, print what is left and stop */
		if ((mode = next_field(file, tpend)) == NULL) {
			(void) printf("<file=%.*s>\n", n, file);
			break;
		}
		n -= mode - file;
		if ((cp = next_field(mode, tpend)) == NULL) {
			(void) printf("<file=%s, mode=%.*s>\n", file, n, mode);
			break;
		}
		(void) printf("<file=%s, mode=%s", file, mode);
		n -= cp - mode;
		if (n > 0) {
			(void) printf(", options: ");
			print_options(stdout, cp, n);
		}
		(void) puts(">");
		break;

	case DATA:
		/* Convert the sizeof total to int so that it matches %d */
		(void) printf("<block=%d, %d bytes>\n", ntohs(tp->th_block),
		    n - (int)(sizeof (tp->th_opcode) +
		    sizeof (tp->th_block)));
		break;

	case ACK:
		(void) printf("<block=%d>\n", ntohs(tp->th_block));
		break;

	case OACK:
		(void) printf("<options: ");
		print_options(stdout, (char *)&tp->th_stuff,
		    n - (int)sizeof (tp->th_opcode));
		(void) puts(">");
		break;

	case ERROR:
		(void) printf("<code=%d", ntohs(tp->th_code));
		n = n - sizeof (tp->th_opcode) - sizeof (tp->th_code);
		if (n > 0)
			(void) printf(", msg=%.*s", n, tp->th_msg);
		(void) puts(">");
		break;

	default:
		/* Unknown opcode: just finish the line started above */
		(void) putchar('\n');
		break;
	}
}

/*
 * gethrtime() readings taken at the start and end of a transfer;
 * printstats() reports the difference between them.
 */
static hrtime_t	tstart, tstop;

/*
 * Record the time at which a transfer starts.
 */
static void
startclock(void)
{
	tstart = gethrtime();
}

/*
 * Record the time at which a transfer ends.
 */
static void
stopclock(void)
{
	tstop = gethrtime();
}

/*
 * Print a summary of a completed transfer on stdout.
 *
 *	direction	word printed first on the line, e.g. "Sent"
 *	amount		number of bytes transferred
 *
 * The elapsed time is the interval between startclock() and stopclock(),
 * shown to a tenth of a second.  In verbose mode the rate in bits per
 * second is added.  The rate is worked out in floating point because
 * amount * 8 * NANOSEC exceeds the range of a 64-bit integer once amount
 * passes roughly 1.15 GB.  It is reported as 0 when no time elapsed.
 */
static void
printstats(char *direction, off_t amount)
{
	hrtime_t	delta, tenths;
	double		rate;
	int64_t		bps;

	delta = tstop - tstart;
	tenths = delta / (NANOSEC / 10);
	(void) printf("%s " OFF_T_FMT " bytes in %" PRId64 ".%" PRId64
	    " seconds", direction, amount, tenths / 10, tenths % 10);
	if (!verbose) {
		(void) putchar('\n');
		return;
	}

	if (delta > 0)
		rate = ((double)amount * 8.0 * (double)NANOSEC) /
		    (double)delta;
	else
		rate = 0.0;	/* avoid dividing by zero */

	/* Clamp so the conversion back to an integer stays in range */
	if (rate >= (double)INT64_MAX)
		bps = INT64_MAX;
	else
		bps = (int64_t)rate;
	(void) printf(" [%" PRId64 " bits/sec]\n", bps);
}