/*
  $Id: mp.c,v 1.12 1996/10/22 00:49:56 luik Exp $

  mp.c - input/output multiplexing routines for omirrd.
  Copyright (C) 1996, Andreas Luik, <luik@pharao.s.bawue.de>.

  This program is free software; you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation; either version 1, or (at your option)
  any later version.

  This program is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with this program; if not, write to the Free Software
  Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.

*/

#define NEED_TIME		/* include <sys/time.h> and <time.h> */
#include "common.h"

#if defined(RCSID) && !defined(lint)
static char rcsid[] UNUSED__ = "$Id: mp.c,v 1.12 1996/10/22 00:49:56 luik Exp $";
#endif /* defined(RCSID) && !defined(lint) */

#include <errno.h>
#include <stdlib.h>
#include <setjmp.h>
#include "sig.h"
#include "debug.h"
#include "mp.h"


/* Fallback for systems without sigsetjmp(3): map the POSIX names onto
   plain setjmp/longjmp.  The save-mask argument S is dropped, so the
   signal mask is not explicitly saved or restored here (whether plain
   setjmp does so is platform-dependent); mpSignalHandler therefore
   unblocks the signal itself before jumping when HAVE_SIGSETJMP is not
   defined.  */
#ifndef HAVE_SIGSETJMP
#define sigjmp_buf	jmp_buf
#define sigsetjmp(E,S)	setjmp(E)
#define siglongjmp(E,V)	longjmp(E, V)
#endif


/* mp_context_head.next points to linked list of mp contexts.  Only the
   `next' field of this dummy head is used: mpAlloc pushes new contexts
   at the front, mpFree unlinks them, and mpSignalHandler walks the list
   to flag signals in every context.  */
static MpContextRec mp_context_head;



/*
 * Multiplex context allocation and deallocation.
 */

MpContext mpAlloc(void)
{
    MpContext context;
    if ((context = calloc(sizeof(MpContextRec), 1))) {
	context->nfds = 0;
	FD_ZERO(&context->readfds);
	FD_ZERO(&context->writefds);
	FD_ZERO(&context->exceptfds);
	context->next = mp_context_head.next;
	mp_context_head.next = context;
    }
    return context;
}

/* mpFree - unlink `context' from the global context list and release
   its memory.  Passing NULL is a no-op, like free(3).  Signal
   dispositions installed through mpSetSignalCallback are not touched
   here.  */
void mpFree(MpContext context)
{
    MpContext mpc;

    /* Without this guard a NULL `context' would match the NULL `next'
       pointer of the last list node and be dereferenced below.  */
    if (context == NULL)
	return;

    for (mpc = &mp_context_head; mpc->next; mpc = mpc->next) {
	if (mpc->next == context) {
	    mpc->next = context->next;
	    break;
	}
    }

    free(context);
}



/*
 * I/O callback handling.
 */

/* mpAddCallback - install `callback' (with closure `data') for `fd' in
   the per-descriptor callback table `cbs'.  `fds' is the descriptor set
   to mark for select(2), or NULL for error callbacks, which are not tied
   to any select(2) condition.  Returns `context' on success, or NULL if
   `fd' is out of range, `callback' is NULL, or a callback is already
   installed for `fd' in `cbs'.  */
static MpContext mpAddCallback(MpContext context, int fd,
			       MpCallbackFunc callback, void *data,
			       fd_set *fds, MpCallback cbs)
{
    MpCallback cb;

    /* Validate before indexing: a negative or too large `fd' would
       address memory outside the callback table.  A NULL callback is
       rejected because mpMainLoop would call it once `fd' is ready.  */
    if (fd < 0 || fd >= FD_SETSIZE || callback == NULL)
	return NULL /* XXX ERROR */;

    cb = &cbs[fd];
    if (cb->callback)		/* never silently replace a callback */
	return NULL /* XXX ERROR */;

    cb->callback = callback;
    cb->data = data;

    if (fds) {
	if (fd + 1 > context->nfds)
	    context->nfds = fd + 1;
	FD_SET(fd, fds);
    }
    return context;
}

/* mpAddReadCallback - call `callback' whenever `fd' is readable.  */
MpContext mpAddReadCallback(MpContext context, int fd,
			    MpCallbackFunc callback, void *data)
{
    return mpAddCallback(context, fd, callback, data,
			 &context->readfds, context->readcbs);
}

/* mpAddWriteCallback - call `callback' whenever `fd' is writable.  */
MpContext mpAddWriteCallback(MpContext context, int fd,
			     MpCallbackFunc callback, void *data)
{
    return mpAddCallback(context, fd, callback, data,
			 &context->writefds, context->writecbs);
}

/* mpAddExceptCallback - call `callback' on exceptional conditions on
   `fd'.  */
MpContext mpAddExceptCallback(MpContext context, int fd,
			     MpCallbackFunc callback, void *data)
{
    return mpAddCallback(context, fd, callback, data,
			 &context->exceptfds, context->exceptcbs);
}

/* mpAddErrorCallback - call `callback' after another callback for `fd'
   has failed (returned -1) or `fd' turned out to be invalid.  */
MpContext mpAddErrorCallback(MpContext context, int fd,
			     MpCallbackFunc callback, void *data)
{
    return mpAddCallback(context, fd, callback, data,
			 NULL, context->errorcbs);
}

/* mpRemoveCallback - uninstall the callback for `fd' in the callback
   table `cbs' and, if `fds' is non-NULL, stop watching `fd' in that
   select(2) set, lowering `context->nfds' when possible.  Removing a
   callback that is not installed is harmless.  Returns NULL if `fd' is
   out of range, else `context'.  */
static MpContext mpRemoveCallback(MpContext context, int fd,
				  fd_set *fds, MpCallback cbs)
{
    MpCallback cb;

    /* Validate before indexing: a negative `fd' would write outside the
       callback table.  */
    if (fd < 0 || fd >= FD_SETSIZE)
	return NULL /* XXX ERROR */;

    cb = &cbs[fd];
    cb->callback = NULL;
    cb->data = NULL;

    if (fds) {
	FD_CLR(fd, fds);

	/* `nfds' is the highest watched descriptor plus one; only
	   removing that highest descriptor can lower it.  */
	if (fd + 1 == context->nfds) {
	    while (context->nfds > 0) {
		int top = context->nfds - 1;

		if (FD_ISSET(top, &context->readfds)
		    || FD_ISSET(top, &context->writefds)
		    || FD_ISSET(top, &context->exceptfds))
		    break;
		context->nfds--;
	    }
	}
    }
    return context;
}

/* mpRemoveReadCallback - stop watching `fd' for readability.  */
MpContext mpRemoveReadCallback(MpContext context, int fd)
{
    return mpRemoveCallback(context, fd,
			    &context->readfds, context->readcbs);
}

/* mpRemoveWriteCallback - stop watching `fd' for writability.  */
MpContext mpRemoveWriteCallback(MpContext context, int fd)
{
    return mpRemoveCallback(context, fd,
			    &context->writefds, context->writecbs);
}

/* mpRemoveExceptCallback - stop watching `fd' for exceptional
   conditions.  */
MpContext mpRemoveExceptCallback(MpContext context, int fd)
{
    return mpRemoveCallback(context, fd,
			    &context->exceptfds, context->exceptcbs);
}

/* mpRemoveErrorCallback - remove the error callback of `fd'.  */
MpContext mpRemoveErrorCallback(MpContext context, int fd)
{
    return mpRemoveCallback(context, fd,
			    NULL, context->errorcbs);
}



/*
 * Signal handling
 */

/* Jump target armed by mpMainLoop just before its select(2) call; when
   armed, mpSignalHandler siglongjmps here instead of returning, so a
   signal cannot slip in between the pending-signal check and select(2)
   unnoticed.  */
static sigjmp_buf jmp_sig_env;	/* used in mpMainLoop to jump over select */
/* Nonzero only while `jmp_sig_env' is armed.  Cleared either by
   mpSignalHandler right before it jumps, or by mpMainLoop once select(2)
   has returned.  */
static volatile sig_atomic_t jmp_sig_env_valid; /* is `jmp_sig_env' valid? */

/* mpSignalHandler - signal handler used for all mp signals.
   For every context on the global list that has a callback installed
   for `sig', increments that callback's pending counter.  The callbacks
   themselves run later, outside signal context, from
   mpCallSignalCallback (reached from mpMainLoop and
   mpSetSignalCallback).  If the signal arrives while mpMainLoop is in
   its critical section around select(2) (`jmp_sig_env_valid' set), the
   handler does not return normally.  Instead it siglongjmps back into
   mpMainLoop, skipping the select(2) call so the signal is processed
   without waiting for I/O.

   XXX This has a race condition: if a second signal interrupts this
   handler before it has incremented the counter for the first signal,
   the nested invocation's siglongjmp abandons the outer one and the
   first signal is lost.  To avoid that, POSIX signals would be required
   (blocking other signals while the handler runs).  Perhaps we should
   switch to real POSIX signal calls, maybe with an emulation using
   4.3BSD sigvec calls.

   NOTE(review): debuglog is called from signal context; if it ends up
   in stdio it is not async-signal-safe -- confirm.  */

static RETSIGTYPE mpSignalHandler(int sig)
{
    MpContext context;
    MpSignalCallback cb;

    /* Ignore signal numbers outside the signalcbs[] table.  */
    if (sig < 1 || sig >= NSIG)
	return;

    debuglog(DEBUG_DAEMON, ("mpSignalHandler: received signal %d\n", sig));

    /* Since we don't have a mp context herein, we set the signal flag
       in all contexts where a callback for that signal is installed.  */
    for (context = mp_context_head.next; context; context = context->next) {
	cb = &context->signalcbs[sig];
	if (cb->callback)
	    cb->flag++;
    }

    if (jmp_sig_env_valid) {	/* if longjmp environment is valid */
	jmp_sig_env_valid = 0;	/* disarm: jump at most once per arming */
#ifndef HAVE_SIGSETJMP
	/* Plain longjmp may not restore the signal mask, and `sig' is
	   presumably blocked while its own handler runs, so unblock it
	   explicitly or it would stay blocked after the jump.  */
	sigUnblock(sig);	/* must unblock signal */
#endif /* !defined(HAVE_SIGSETJMP) */
	siglongjmp(jmp_sig_env, 1); /* jump back to mpMainLoop */
    }
}

/* mpPendingSignals - report whether `context' has signals that were
   caught by `mpSignalHandler' but have not yet been handed to their
   callbacks by `mpCallSignalCallback'.  Returns 1 if at least one such
   signal is pending, 0 otherwise.  */
static int mpPendingSignals(MpContext context)
{
    int signo = 1;

    /* Stop at the first signal number with a nonzero pending counter. */
    while (signo < NSIG && !context->signalcbs[signo].flag)
	signo++;

    return signo < NSIG;
}

/* mpCallSignalCallback - run the callback installed in `context' for
   `sig' once for every occurrence counted by mpSignalHandler.  `sig' is
   blocked while the counter is drained so the handler cannot modify it
   concurrently, and its previous blocking state is restored afterwards.
   Out-of-range signal numbers and signals without a callback are
   ignored.  */
static void mpCallSignalCallback(MpContext context, int sig)
{
    MpSignalCallback handler;
    int was_blocked;

    if (sig <= 0 || sig >= NSIG)
	return;

    handler = &context->signalcbs[sig];
    if (handler->callback == NULL)
	return;

    was_blocked = sigBlock(sig); /* keep mpSignalHandler out meanwhile */
    while (handler->flag != 0) {
	handler->flag--;
	debuglog(DEBUG_DAEMON,
		 ("mpCallSignalCallback: calling callback for signal %d\n",
		  sig));
	(*handler->callback)(sig, handler->data);
    }
    if (!was_blocked)		/* it was unblocked on entry: restore */
	sigUnblock(sig);
}

/* mpCallSignalCallbacks - deliver the pending signals of `context' to
   their callbacks, in ascending signal-number order.  */
static void mpCallSignalCallbacks(MpContext context)
{
    int signo = 1;

    while (signo < NSIG) {
	mpCallSignalCallback(context, signo);
	signo++;
    }
}

/* mpSetSignalCallback - install `callback' (with closure `data') as the
   handler for signal `sig' in `context'.  If `ocallback' is non-NULL it
   receives the previously installed callback, or NULL if there was none
   (the state after mpAlloc or after MP_SIG_DFL/MP_SIG_IGN).  Passing
   MP_SIG_DFL or MP_SIG_IGN removes the context's callback and sets the
   signal's disposition to SIG_DFL or SIG_IGN via sigHandler.  Signals
   already caught but not yet delivered are passed to the old callback
   first.  Returns NULL for an invalid signal number, else `context'.

   NOTE(review): resetting the disposition also affects other contexts
   that may still have a callback for `sig'.  */
MpContext mpSetSignalCallback(MpContext context, int sig,
			      MpSignalCallbackFunc callback,
			      MpSignalCallbackFunc *ocallback,
			      void *data)
{
    int blocked;
    MpSignalCallback cb;

    if (sig < 1 || sig >= NSIG)
	return NULL;

    cb = &context->signalcbs[sig];
    if (ocallback)
	*ocallback = cb->callback;

    blocked = sigBlock(sig);	/* protect against another signal */
    mpCallSignalCallback(context, sig); /* call pending signal handlers */
    if (callback == MP_SIG_DFL || callback == MP_SIG_IGN) {
	/* Clear the callback too: a stale one would otherwise be flagged
	   and invoked (with NULL data) if another context re-installs
	   mpSignalHandler, and would be reported through `ocallback'.  */
	cb->callback = NULL;
	cb->data = NULL;
	cb->flag = 0;
	sigHandler(sig, callback == MP_SIG_DFL ? SIG_DFL : SIG_IGN);
    }
    else {
	cb->callback = callback;
	cb->data = data;
	sigHandler(sig, mpSignalHandler);
    }
    if (blocked == 0)		/* signal was not blocked before, */
	sigUnblock(sig);	/* therefore unblock it */

    return context;
}



/* mpFlush - flush the output of `fd' by explicitly calling its write
   callback.  As in mpMainLoop, a write callback returning -1 is removed
   and then the error callback for `fd', if any, is called.  Returns
   NULL if `fd' is out of range or has no write callback, else
   `context'.  */
MpContext mpFlush(MpContext context, int fd)
{
    MpCallback cb;

    /* Check both bounds before indexing writecbs[].  */
    if (fd < 0 || fd >= FD_SETSIZE)
	return NULL;

    cb = &context->writecbs[fd];
    if (cb->callback == NULL)
	return NULL;

    if ((*cb->callback)(fd, cb->data) == -1) {
	mpRemoveWriteCallback(context, fd);
	cb = &context->errorcbs[fd];
	if (cb->callback)
	    (*cb->callback)(fd, cb->data);
    }

    return context;
}



/* mpMainLoop - wait for I/O event on one or more of the file
   descriptors set in the multiplex context and call the corresponding
   callback functions. If the callback function returns an error
   (return value -1), the callback for this file descriptor is
   immediately removed (to prevent busy loops) and then the error
   callback is called, if one was installed. This error callback can
   fix the error condition and perhaps re-install the action
   callback, if desired. This function loops until select() returns
   error.  */
int mpMainLoop(MpContext context)
{
    int i;
    int result = 0;
    fd_set readfds;
    fd_set writefds;
    fd_set exceptfds;
    MpCallback cb;

    while (result >= 0) {	/* while no error in select */
	FD_ZERO(&readfds);
	FD_ZERO(&writefds);
	FD_ZERO(&exceptfds);
	for (i = 0; i < context->nfds; i++) {
	    if (FD_ISSET(i, &context->readfds))
		FD_SET(i, &readfds);
	    if (FD_ISSET(i, &context->writefds))
		FD_SET(i, &writefds);
	    if (FD_ISSET(i, &context->exceptfds))
		FD_SET(i, &exceptfds);
	}

	/* Use setjmp/longjmp from signal handler to avoid race
           condition if signal occurs between `mpPendingSignals' and
           the call to select.  */
	if (sigsetjmp(jmp_sig_env, 1) == 0) {
	    jmp_sig_env_valid = 1;
	    if (!mpPendingSignals(context))
		result = select(context->nfds, &readfds, &writefds, &exceptfds,
				NULL);
	    else result = -1, errno = EINTR; /* pending signals */
	    jmp_sig_env_valid = 0;
	}
	else result = -1, errno = EINTR; /* signal occurred */

	if (result == -1 && errno == EINTR) {
	    result = 0;
	    mpCallSignalCallbacks(context);
	}
	else if (result == -1 && errno == EBADF) {
	    struct timeval zero_timeout;
	    /* One (or more) of the file descriptors in the masks is
	       illegal. Unfortunately, select(2) does not return which
	       ones are illegal. Therefore we have to check each one
	       separatly.  */
	    for (i = 0; i < context->nfds; i++) {
		FD_ZERO(&readfds);
		FD_ZERO(&writefds);
		FD_ZERO(&exceptfds);
		if (FD_ISSET(i, &context->readfds))
		    FD_SET(i, &readfds);
		if (FD_ISSET(i, &context->writefds))
		    FD_SET(i, &writefds);
		if (FD_ISSET(i, &context->exceptfds))
		    FD_SET(i, &exceptfds);
		zero_timeout.tv_sec = 0;
		zero_timeout.tv_usec = 0;
		result = select(context->nfds, &readfds, &writefds, &exceptfds,
				&zero_timeout);
		if (result == -1 && errno == EBADF) {
		    if (FD_ISSET(i, &context->exceptfds))
			mpRemoveExceptCallback(context, i);
		    if (FD_ISSET(i, &context->readfds))
			mpRemoveReadCallback(context, i);
		    if (FD_ISSET(i, &context->writefds))
			mpRemoveWriteCallback(context, i);
		    cb = &context->errorcbs[i];
		    if (cb->callback)
			(*cb->callback)(i, cb->data);
		}
	    }
	    result = 0;
	}
	else if (result > 0) {
	    for (i = 0; result > 0 && i < context->nfds; i++) {
		if (FD_ISSET(i, &exceptfds)) {
		    cb = &context->exceptcbs[i];
		    if ((*cb->callback)(i, cb->data) == -1) {
			mpRemoveExceptCallback(context, i);
			cb = &context->errorcbs[i];
			if (cb->callback)
			    (*cb->callback)(i, cb->data);
		    }
		    result--;
		}
	    }
	    for (i = 0; result > 0 && i < context->nfds; i++) {
		if (FD_ISSET(i, &readfds)) {
		    cb = &context->readcbs[i];
		    if ((*cb->callback)(i, cb->data) == -1) {
			mpRemoveReadCallback(context, i);
			cb = &context->errorcbs[i];
			if (cb->callback)
			    (*cb->callback)(i, cb->data);
		    }
		    result--;
		}
	    }
	    for (i = 0; result > 0 && i < context->nfds; i++) {
		if (FD_ISSET(i, &writefds)) {
		    cb = &context->writecbs[i];
		    if ((*cb->callback)(i, cb->data) == -1) {
			mpRemoveWriteCallback(context, i);
			cb = &context->errorcbs[i];
			if (cb->callback)
			    (*cb->callback)(i, cb->data);
		    }
		    result--;
		}
	    }
	}
    }
    return (result);
}


