#ifdef HAVE_CONFIG_H
# include "config.h"
#endif

#include "socktool.h"
#include "xwrap.h"

#include <arpa/inet.h>
#include <errno.h>
#include <limits.h>
#include <netinet/in.h>
#include <signal.h>
#include <stdbool.h>
#include <string.h>
#include <sys/select.h>
#include <sys/socket.h>
#include <sys/un.h>
#include <unistd.h>

typedef int (*SockNameGetter) (int, struct sockaddr*, socklen_t*);

/*
 * Describe one end of sockfd as a newly allocated string (caller frees).
 *
 * getter selects which end is queried (getsockname or getpeername).
 * Result:
 *   - "invalid socket"  if getter fails,
 *   - "local process"   for AF_UNIX sockets,
 *   - the numeric address from inet_ntop for AF_INET; every other family
 *     is formatted as AF_INET6,
 *   - "unknown address" if inet_ntop cannot format it.
 */
static char* ras_socktool_get_name (int sockfd, SockNameGetter getter) {
	struct sockaddr_storage addr;
	socklen_t addrlen = sizeof (addr);

	memset (&addr, 0, sizeof (addr));
	if (getter (sockfd, SOCKADDR (&addr), &addrlen) < 0) {
		return xstrdup ("invalid socket");
	}

	int family = addr.ss_family;
	if (family == AF_UNIX) {
		return xstrdup ("local process");
	}

	bool is_v4 = (family == AF_INET);
	socklen_t textlen = is_v4 ? INET_ADDRSTRLEN : INET6_ADDRSTRLEN;
	void* raw = is_v4
		? (void*) &(SOCKADDR_IN (&addr)->sin_addr)
		: (void*) &(SOCKADDR_IN6 (&addr)->sin6_addr);

	char* text = xmalloc (textlen);
	if (inet_ntop (family, raw, text, textlen) != NULL) {
		return text;
	}

	free (text);
	return xstrdup ("unknown address");
}

char* ras_socktool_get_sockname (int sockfd) {
	return ras_socktool_get_name (sockfd, getsockname);
}

char* ras_socktool_get_peername (int sockfd) {
	return ras_socktool_get_name (sockfd, getpeername);
}

/*
 * Reset buf to an empty state: no buffered input, no partial line, and
 * cleared error/EOF flags.
 *
 * initial: pass true when buf->buf_line has never been set (e.g. freshly
 * declared storage), so it is not handed to free(); pass false to release
 * any partial line still held.
 */
void ras_socktool_buffer_clear (RasBuffer* buf, bool initial) {
	if (!initial) {
		free (buf->buf_line);
	}

	buf->buf_line = NULL;
	buf->buf_line_len = 0;

	buf->buf_start = 0;
	buf->buf_len = 0;

	buf->buf_eof = false;
	buf->buf_error = false;
}

/*
 * Discard the partially accumulated line held in buf (if any) and mark the
 * line as empty. Buffered input bytes and error/EOF flags are left untouched.
 */
void ras_socktool_buffer_free_line (RasBuffer* buf) {
	char* pending = buf->buf_line;

	buf->buf_line_len = 0;
	buf->buf_line = NULL;
	free (pending);
}

/*
 * Incrementally read one delim-terminated line from sockfd through buf.
 *
 * At most one read() is issued per call, and only when every previously
 * buffered byte has been consumed. Bytes up to the delimiter are appended to
 * buf->buf_line, which survives across calls until a delimiter arrives.
 *
 * Returns a newly allocated, NUL-terminated line without the delimiter
 * (and without a trailing '\r' when delim is '\n'); ownership passes to the
 * caller and, if len is non-NULL, *len receives its length.
 *
 * Returns NULL when no complete line is available yet, including:
 *   - read() failed with EAGAIN/EINTR (transient, nothing latched),
 *   - read() failed otherwise (buf->buf_error is set),
 *   - end of file (buf->buf_eof is set; any bytes after the last delimiter
 *     remain in buf->buf_line).
 * Once buf_error or buf_eof is set, every later call returns NULL.
 */
char* ras_socktool_getline (int sockfd, RasBuffer* buf, int delim, int* len) {
	if (buf->buf_error || buf->buf_eof) {
		return NULL;
	}

	/* Refill only once all previously read bytes have been consumed. */
	if (buf->buf_len == 0) {
		int rval = read (sockfd, buf->buf, RAS_BUFFER_SIZE);
		if (rval < 0) {
			/* EAGAIN/EINTR are transient: report "no line yet" without latching. */
			if (errno != EAGAIN && errno != EINTR) {
				buf->buf_error = true;
			}
			return NULL;
		}
		if (rval == 0) {
			buf->buf_eof = true;
			return NULL;
		}
		buf->buf_start = 0;
		buf->buf_len = rval;
	}

	/* Scan the unconsumed bytes; i == buf_len means the delimiter is absent. */
	int i;
	for (i = 0; i < buf->buf_len; i++) {
		if (buf->buf[buf->buf_start + i] == delim) {
			break;
		}
	}
	bool found = i < buf->buf_len;

	/* Append the bytes preceding the delimiter and keep the line NUL-terminated. */
	int buf_line_start = buf->buf_line_len;
	buf->buf_line_len += i;
	buf->buf_line = xrealloc (buf->buf_line, buf->buf_line_len + 1);
	memcpy (buf->buf_line + buf_line_start, buf->buf + buf->buf_start, i);
	buf->buf_line[buf->buf_line_len] = '\0';

	/*
	 * Strip the CR of a CRLF terminator. The check runs on the accumulated
	 * line, so a CR that arrived in an earlier read() is removed as well.
	 * "> 0" (rather than "- 1 >= 0") stays correct even for unsigned lengths.
	 */
	if (found && delim == '\n' && buf->buf_line_len > 0 &&
		buf->buf_line[buf->buf_line_len - 1] == '\r') {
		buf->buf_line_len--;
		buf->buf_line[buf->buf_line_len] = '\0';
	}

	/*
	 * Consume the scanned bytes plus the delimiter. Without a delimiter the
	 * whole chunk was consumed; avoid subtracting past zero so this also
	 * works if the length fields are unsigned.
	 */
	if (found) {
		buf->buf_start += i + 1;
		buf->buf_len -= i + 1;
	} else {
		buf->buf_len = 0;
	}
	if (buf->buf_len == 0) {
		buf->buf_start = 0;
	}

	if (!found) {
		return NULL;
	}

	/* Hand the completed line to the caller and start a fresh one. */
	char* newstr = buf->buf_line;
	if (len != NULL) {
		*len = buf->buf_line_len;
	}
	buf->buf_line = NULL;
	buf->buf_line_len = 0;
	return newstr;
}

/*
 * Write size bytes of str to sockfd, retrying short writes, EINTR and EAGAIN.
 *
 * size: number of bytes to write; 0 means strlen(str).
 * Returns the number of bytes actually written. This is fewer than requested
 * if write() fails with any other error or makes no progress (returns 0).
 *
 * Note: EAGAIN is retried immediately, so on a non-blocking descriptor this
 * busy-waits until the descriptor becomes writable.
 */
int ras_socktool_write_string (int sockfd, const char* str, size_t size) {
	if (size == 0) {
		size = strlen (str);
	}

	size_t rem = size;
	while (rem > 0) {
		ssize_t wtn = write (sockfd, str, rem);
		if (wtn < 0) {
			if (errno == EINTR || errno == EAGAIN) {
				continue;
			}
			break;
		}
		if (wtn == 0) {
			/* No progress and no error: stop rather than spin forever. */
			break;
		}
		str += wtn;
		rem -= (size_t) wtn;
	}

	return (int) (size - rem);
}

/*
 * Restore the default disposition (SIG_DFL, empty mask, no flags) for
 * signals, e.g. before running a child program.
 *
 * When the platform defines SIGRTMIN/SIGRTMAX, every signal number from 1
 * through SIGRTMAX inclusive is reset. The explicit must_reset list is always
 * applied as well, so the common signals are covered even on platforms
 * without real-time signals.
 *
 * sigaction() failures are ignored on purpose. For example, SIGKILL and
 * SIGSTOP cannot be changed and some numbers may be reserved by libc.
 */
void ras_socktool_reset_signals (void) {
	static const int must_reset[] = { SIGHUP, SIGINT, SIGQUIT, SIGABRT,
		SIGFPE, SIGTERM, SIGCHLD, SIGTSTP, SIGTTIN, SIGTTOU };
	struct sigaction action = {
		.sa_handler = SIG_DFL,
		.sa_flags = 0
	};
	sigemptyset (&action.sa_mask);
#if defined(SIGRTMIN) && defined(SIGRTMAX)
	/* SIGRTMAX is itself a valid signal number, so the sweep is inclusive. */
	int signal_max = xmax (SIGRTMIN, SIGRTMAX);
	for (int sig = 1; sig <= signal_max; sig++) {
		sigaction (sig, &action, NULL);
	}
#endif
	for (size_t i = 0; i < sizeof (must_reset) / sizeof (must_reset[0]); i++) {
		sigaction (must_reset[i], &action, NULL);
	}
}

/*
 * Return the size of the process descriptor table, i.e. one more than the
 * highest descriptor number that may be open.
 *
 * Uses getdtablesize() when configure found it. Otherwise it uses
 * sysconf(_SC_OPEN_MAX), falling back to _POSIX_OPEN_MAX (or 20) when
 * sysconf reports -1 (an error or no determinate limit). The result is
 * clamped to INT_MAX because the limit may not fit in an int.
 */
int ras_socktool_getdtablesize (void) {
#ifdef HAVE_GETDTABLESIZE
	return getdtablesize ();
#else
	long rval = sysconf (_SC_OPEN_MAX);
	if (rval < 0) {
# ifdef _POSIX_OPEN_MAX
		rval = _POSIX_OPEN_MAX;
# else
		rval = 20;
# endif
	}
	if (rval > INT_MAX) {
		rval = INT_MAX;
	}
	return (int) rval;
#endif
}

/*
 * Close every descriptor from 0 up to ras_socktool_getdtablesize() - 1,
 * in ascending order, except dontclose. Errors from close() are ignored,
 * since most slots are typically not open (best effort).
 */
void ras_socktool_close_except (int dontclose) {
	const int limit = ras_socktool_getdtablesize ();

	for (int fd = 0; fd < limit; fd++) {
		if (fd == dontclose) {
			continue;
		}
		close (fd);
	}
}

/*
 * Relay data in both directions between fd[0] and fd[1] using select().
 *
 * buf[i] holds bytes read from fd[i] that are still to be written to fd[!i].
 * While buf[i] is non-empty, fd[i] is dropped from the read set and fd[!i]
 * is watched for writability. So at most one buffer of data is in flight
 * per direction.
 *
 * Before every select() round prog_cb is called with the descriptor pair and
 * both buffers; a zero return stops the relay. The relay also stops when:
 *   - either side hits EOF or a non-transient read/write error,
 *   - select() fails with anything other than EINTR,
 *   - either descriptor cannot be used with fd_set (outside [0, FD_SETSIZE)).
 * Pending buffered data is not flushed on exit, and the descriptors are not
 * closed here.
 */
void ras_socktool_exchange_data (int fd[2], RasSocktoolProgress prog_cb) {
	int nfds, active;
	fd_set rset, wset;
	fd_set rres, wres;
	RasBuffer buf[2];

	/* FD_SET/FD_ISSET on a descriptor outside [0, FD_SETSIZE) is undefined behavior. */
	for (int i = 0; i < 2; i++) {
		if (fd[i] < 0 || fd[i] >= FD_SETSIZE) {
			return;
		}
	}

	FD_ZERO (&rset);
	FD_SET (fd[0], &rset);
	FD_SET (fd[1], &rset);
	FD_ZERO (&wset);

	active = 2;
	nfds = xmax (fd[0], fd[1]) + 1;
	ras_socktool_buffer_clear (&buf[0], true);
	ras_socktool_buffer_clear (&buf[1], true);

	while (active && (*prog_cb) (fd, buf)) {
		rres = rset;
		wres = wset;
		if (select (nfds, &rres, &wres, NULL, NULL) < 0) {
			if (errno == EINTR) {
				continue;
			}
			/* Other select() errors (e.g. EBADF) persist; retrying would spin forever. */
			break;
		}

		for (int i = 0; i < 2; i++) {
			int src = fd[i];
			int dst = fd[!i];

			if (buf[i].buf_len) {
				/* Data pending from src: flush as much as dst accepts. */
				if (!FD_ISSET (dst, &wres)) {
					continue;
				}
				ssize_t wb = write (dst,
					buf[i].buf + buf[i].buf_start, buf[i].buf_len);
				if (wb > 0) {
					buf[i].buf_start += wb;
					buf[i].buf_len -= wb;
					if (!buf[i].buf_len) {
						/* Drained: stop watching dst, resume reading src. */
						FD_CLR (dst, &wset);
						FD_SET (src, &rset);
					}
				} else if (wb < 0 && (errno == EAGAIN || errno == EINTR)) {
					/* Transient; retry on the next select() round. */
					continue;
				} else {
					FD_CLR (dst, &wset);
					active = 0;
				}
			} else {
				/* Buffer empty: refill it from src if readable. */
				if (!FD_ISSET (src, &rres)) {
					continue;
				}
				ssize_t rb = read (src, buf[i].buf, RAS_BUFFER_SIZE);
				if (rb > 0) {
					buf[i].buf_start = 0;
					buf[i].buf_len = rb;
					/* Hold off reading src until the data reaches dst. */
					FD_CLR (src, &rset);
					FD_SET (dst, &wset);
				} else if (rb < 0 && (errno == EAGAIN || errno == EINTR)) {
					/* Transient; retry on the next select() round. */
					continue;
				} else {
					/* EOF or hard error on src ends the relay. */
					FD_CLR (src, &rset);
					active = 0;
				}
			}
		}
	}
}