/*
 * libsyncml - A syncml protocol implementation
 * Copyright (C) 2005  Armin Bauer <armin.bauer@opensync.org>
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public
 * License as published by the Free Software Foundation; version 
 * 2.1 of the License.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with this library; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307  USA
 *
 */

#include <libsyncml/syncml.h>
#include "sml_auth.h"

#include <libsyncml/syncml_internals.h>
#include "sml_auth_internals.h"
#include <libsyncml/sml_session_internals.h>
#include <libsyncml/sml_elements_internals.h>
#include <libsyncml/sml_command_internals.h>

/**
 * @defgroup GroupIDPrivate Group Description Internals
 * @ingroup ParentGroupID
 * @brief Private helpers for SyncML header authentication
 * 
 */
/*@{*/

/*
 * Status handler registered together with _header_callback by
 * smlAuthRegister(). It currently only traces its invocation and does not
 * inspect the received status.
 */
void _status_callback(SmlSession *session, SmlStatus *status, void *userdata)
{
	smlTrace(TRACE_ENTRY, "%s(%p, %p, %p)", __func__, session, status, userdata);
	/* No-op: nothing is read from or done with the status here. */
	smlTrace(TRACE_EXIT, "%s", __func__);
}

/**
 * Header handler registered by smlAuthRegister(): checks the credentials
 * (if any) of an incoming SyncML header, updates auth->state, and sends the
 * header status reply built by smlAuthHeaderReply().
 *
 * Outcomes, by (auth->enabled, cred present):
 *  - disabled, no cred:                    state = SML_NO_ERROR
 *  - enabled, no cred, not yet accepted:   state = SML_ERROR_AUTH_REQUIRED
 *                                          and an error event is dispatched
 *  - enabled, no cred, already accepted;
 *    or disabled, cred given:              state = SML_AUTH_ACCEPTED
 *  - enabled, cred given:                  verified according to cred->type
 *
 * If the resulting state is SML_ERROR_AUTH_REJECTED or
 * SML_ERROR_AUTH_REQUIRED, session->end is set before the reply is sent.
 * On internal failures (base64 decode, unknown cred type, building or
 * sending the reply) an error event is dispatched and the function returns
 * without a reply having been sent successfully.
 *
 * @param session  session that received the header; must not be NULL
 * @param header   the received header (only traced here)
 * @param cred     credentials from the header, or NULL if none were sent
 * @param userdata the SmlAuthenticator; must not be NULL
 */
void _header_callback(SmlSession *session, SmlHeader *header, SmlCred *cred, void *userdata)
{
	smlTrace(TRACE_ENTRY, "%s(%p, %p, %p, %p)", __func__, session, header, cred, userdata);
	smlAssert(session);
	smlAssert(userdata);
	SmlStatus *reply = NULL;
	SmlAuthenticator *auth = userdata;
	SmlError *error = NULL;
	char *buffer = NULL;
	unsigned int buffersize = 0;
	
	if (!cred && !auth->enabled) {
		smlTrace(TRACE_INTERNAL, "Auth is disabled and no cred given");
		auth->state = SML_NO_ERROR;
	} else if (!cred && auth->enabled && auth->state != SML_AUTH_ACCEPTED) {
		smlTrace(TRACE_INTERNAL, "Auth is required");
		auth->state = SML_ERROR_AUTH_REQUIRED;
		
		smlErrorSet(&error, SML_ERROR_AUTH_REQUIRED, "Auth required but not given");
		smlSessionDispatchEvent(session, SML_SESSION_EVENT_ERROR, NULL, NULL, NULL, error);
		smlErrorDeref(&error);
	} else if ((!cred && auth->enabled && auth->state == SML_AUTH_ACCEPTED) ||
		(cred && !auth->enabled)) {
		smlTrace(TRACE_INTERNAL, "Auth is already accepted %i", auth->enabled);
		auth->state = SML_AUTH_ACCEPTED;
	} else {
		/* Only cred && auth->enabled reaches here. cred->data is not traced:
		 * for basic auth it is a trivially reversible encoding of
		 * "username:password" and must not end up in log files. */
		smlTrace(TRACE_INTERNAL, "Checking credentials of type %i", cred->type);
		
		switch (cred->type) {
			case SML_AUTH_TYPE_BASIC: {
				if (!smlBase64Decode(cred->data, &buffer, &buffersize, &error))
					goto error;
				
				/* Expected decoded form is "username:password". g_strsplit()
				 * returns an EMPTY vector for "" (arr[0] == NULL) and a
				 * one-element vector when there is no ':' (arr[1] == NULL),
				 * so arr[1] may only be read once arr[0] is known non-NULL.
				 * A NULL buffer is treated like "" to keep arr non-NULL. */
				char **arr = g_strsplit(buffer ? buffer : "", ":", 2);
				g_free(buffer);
				buffer = NULL;
				
				char *username = arr[0];
				char *password = username ? arr[1] : NULL;
				
				/* The password is deliberately never traced. */
				smlTrace(TRACE_INTERNAL, "Username \"%s\"", username ? username : "");
				
				if (!password) {
					/* Missing separator (or empty credential): reject
					 * instead of handing NULL to the verify callback. */
					smlTrace(TRACE_INTERNAL, "Malformed basic auth credentials");
					auth->state = SML_ERROR_AUTH_REJECTED;
				} else if (auth->verifyCallback) {
					auth->verifyCallback(auth, username, password, auth->verifyCallbackUserdata, &auth->state);
				} else {
					smlTrace(TRACE_INTERNAL, "No verify callback set");
					auth->state = SML_ERROR_AUTH_REJECTED;
				}
				
				if (auth->state == SML_ERROR_AUTH_REJECTED) {
					smlErrorSet(&error, SML_ERROR_AUTH_REJECTED, "Auth rejected for username %s", username ? username : "");
					smlSessionDispatchEvent(session, SML_SESSION_EVENT_ERROR, NULL, NULL, NULL, error);
					smlErrorDeref(&error);
				}
				
				g_strfreev(arr);
				break;
			}
			case SML_AUTH_TYPE_MD5:
				/* NOTE(review): MD5 credentials are not verified here;
				 * auth->state is left unchanged. */
				break;
			default:
				smlErrorSet(&error, SML_ERROR_GENERIC, "Unknown auth format");
				goto error;
		}
	}
	
	if (auth->state == SML_ERROR_AUTH_REJECTED || auth->state == SML_ERROR_AUTH_REQUIRED) {
		smlTrace(TRACE_INTERNAL, "Ending session due to wrong / missing creds");
		session->end = TRUE;
	}
	
	reply = smlAuthHeaderReply(session, auth->state, &error);
	if (!reply)
		goto error;
	
	if (!smlSessionSendReply(session, reply, &error)) {
		smlStatusUnref(reply);
		goto error;
	}
	
	smlStatusUnref(reply);
	
	smlTrace(TRACE_EXIT, "%s", __func__);
	return;

error:
	smlSessionDispatchEvent(session, SML_SESSION_EVENT_ERROR, NULL, NULL, NULL, error);
	smlTrace(TRACE_EXIT_ERROR, "%s: %s", __func__, smlErrorPrint(&error));
	smlErrorDeref(&error);
	return;
}

/*@}*/

/**
 * @defgroup GroupID Group Description
 * @ingroup ParentGroupID
 * @brief Public API for authenticating SyncML sessions (SmlAuthenticator)
 * 
 */
/*@{*/

/**
 * Allocates a new, zero-initialised authenticator.
 *
 * The authenticator starts out enabled and in state SML_ERROR_AUTH_REQUIRED,
 * so a header without credentials is not accepted until the state changes.
 *
 * @param error receives the error if the allocation fails
 * @returns the new authenticator (release with smlAuthFree()), or NULL on failure
 */
SmlAuthenticator *smlAuthNew(SmlError **error)
{
	smlTrace(TRACE_ENTRY, "%s(%p)", __func__, error);
	
	SmlAuthenticator *authenticator = smlTryMalloc0(sizeof(SmlAuthenticator), error);
	if (authenticator == NULL) {
		smlTrace(TRACE_EXIT_ERROR, "%s: %s", __func__, smlErrorPrint(error));
		return NULL;
	}
	
	/* Secure default: authentication required until proven otherwise. */
	authenticator->state = SML_ERROR_AUTH_REQUIRED;
	authenticator->enabled = TRUE;
	
	smlTrace(TRACE_EXIT, "%s: %p", __func__, authenticator);
	return authenticator;
}


/**
 * Releases an authenticator created by smlAuthNew().
 *
 * Only the authenticator struct itself is freed; the verify-callback
 * userdata pointer stored in it is not touched.
 *
 * @param auth authenticator to free; must not be NULL
 */
void smlAuthFree(SmlAuthenticator *auth)
{
	smlTrace(TRACE_ENTRY, "%s(%p)", __func__, auth);
	smlAssert(auth);
	g_free(auth);
	smlTrace(TRACE_EXIT, "%s", __func__);
}

/**
 * Registers the authenticator's header and status handlers
 * (_header_callback / _status_callback) with a manager.
 *
 * @param auth    authenticator, passed as the handlers' userdata; must not be
 *                NULL (presumably it must outlive the registration — confirm)
 * @param manager manager to register with; must not be NULL
 * @param error   not used by this function
 * @returns always TRUE
 */
SmlBool smlAuthRegister(SmlAuthenticator *auth, SmlManager *manager, SmlError **error)
{
	smlTrace(TRACE_ENTRY, "%s(%p, %p, %p)", __func__, auth, manager, error);
	smlAssert(auth);
	smlAssert(manager);
	
	SmlManager *target = manager;
	smlManagerRegisterHeaderHandler(target, _header_callback, _status_callback, auth);
	
	smlTrace(TRACE_EXIT, "%s", __func__);
	return TRUE;
}

/**
 * Overrides the authenticator's current authentication state.
 *
 * The state is consulted by the header handler (e.g. SML_AUTH_ACCEPTED lets
 * later headers without credentials pass) and is sent as the header status.
 *
 * @param auth authenticator to modify; must not be NULL
 * @param type new state value
 */
void smlAuthSetState(SmlAuthenticator *auth, SmlErrorType type)
{
	smlTrace(TRACE_ENTRY, "%s(%p, %i)", __func__, auth, type);
	smlAssert(auth);
	SmlErrorType newState = type;
	auth->state = newState;
	smlTrace(TRACE_EXIT, "%s", __func__);
}

/**
 * Builds the status reply for a received SyncML header.
 *
 * The status uses the session's last received message ID, source and target
 * and the HEADER command type. When @p code is SML_ERROR_AUTH_REJECTED or
 * SML_ERROR_AUTH_REQUIRED, a credential element (type basic, format base64)
 * is attached to the status.
 *
 * NOTE(review): @p code is declared as SmlAuthType, but it is compared
 * against (and callers pass) SmlErrorType values; the prototype should
 * probably use SmlErrorType (needs a matching header change).
 *
 * @param session session the header was received on
 * @param code    status code to report
 * @param error   receives the error on failure
 * @returns the new status (the caller unrefs it), or NULL on failure
 */
SmlStatus *smlAuthHeaderReply(SmlSession *session, SmlAuthType code, SmlError **error)
{
	smlTrace(TRACE_ENTRY, "%s(%p, %i, %p)", __func__, session, code, error);
	
	SmlStatus *status = smlStatusNew(code, 0, session->lastReceivedMessageID, session->source, session->target, SML_COMMAND_TYPE_HEADER, error);
	if (status == NULL) {
		smlTrace(TRACE_EXIT_ERROR, "%s: %s", __func__, smlErrorPrint(error));
		return NULL;
	}
	
	if (code == SML_ERROR_AUTH_REQUIRED || code == SML_ERROR_AUTH_REJECTED) {
		status->cred = smlTryMalloc0(sizeof(SmlCred), error);
		if (status->cred == NULL) {
			smlStatusUnref(status);
			smlTrace(TRACE_EXIT_ERROR, "%s: %s", __func__, smlErrorPrint(error));
			return NULL;
		}
		
		status->cred->format = SML_FORMAT_TYPE_BASE64;
		status->cred->type = SML_AUTH_TYPE_BASIC;
	}
	
	smlTrace(TRACE_EXIT, "%s: %p", __func__, status);
	return status;
}

/**
 * Installs the callback used to verify basic-auth username/password pairs.
 *
 * Without a callback, any presented basic credentials are rejected by the
 * header handler.
 *
 * @param auth     authenticator to modify; must not be NULL
 * @param callback verifier, or NULL to clear it
 * @param userdata opaque pointer handed back to @p callback
 */
void smlAuthSetVerifyCallback(SmlAuthenticator *auth, SmlAuthVerifyCb callback, void *userdata)
{
	smlTrace(TRACE_ENTRY, "%s(%p, %p, %p)", __func__, auth, callback, userdata);
	smlAssert(auth);
	
	auth->verifyCallbackUserdata = userdata;
	auth->verifyCallback = callback;
	
	smlTrace(TRACE_EXIT, "%s", __func__);
}

/**
 * Turns credential checking on or off.
 *
 * When disabled, headers without credentials are answered with SML_NO_ERROR
 * and headers with credentials are accepted without verification.
 *
 * @param auth    authenticator to modify; must not be NULL
 * @param enabled TRUE to require authentication, FALSE to disable it
 */
void smlAuthSetEnable(SmlAuthenticator *auth, SmlBool enabled)
{
	smlTrace(TRACE_ENTRY, "%s(%p, %i)", __func__, auth, enabled);
	smlAssert(auth);
	SmlBool flag = enabled;
	auth->enabled = flag;
	smlTrace(TRACE_EXIT, "%s", __func__);
}

/**
 * Reports whether credential checking is enabled.
 *
 * @param auth authenticator to query; must not be NULL
 * @returns the value last set by smlAuthSetEnable() (TRUE after smlAuthNew())
 */
SmlBool smlAuthIsEnabled(SmlAuthenticator *auth)
{
	smlTrace(TRACE_ENTRY, "%s(%p)", __func__, auth);
	smlAssert(auth);
	
	SmlBool enabled = auth->enabled;
	
	/* The exit trace must precede the return. It used to sit after it and
	 * was unreachable, leaving the trace entry/exit nesting unbalanced. */
	smlTrace(TRACE_EXIT, "%s: %i", __func__, enabled);
	return enabled;
}

/*@}*/
