/*
 * Copyright (c) 2001-2006, John Mettraux, OpenWFE.org
 * All rights reserved.
 * 
 * Redistribution and use in source and binary forms, with or without 
 * modification, are permitted provided that the following conditions are met:
 * 
 * . Redistributions of source code must retain the above copyright notice, this
 *   list of conditions and the following disclaimer.  
 * 
 * . Redistributions in binary form must reproduce the above copyright notice, 
 *   this list of conditions and the following disclaimer in the documentation 
 *   and/or other materials provided with the distribution.
 * 
 * . Neither the name of the "OpenWFE" nor the names of its contributors may be
 *   used to endorse or promote products derived from this software without
 *   specific prior written permission.
 * 
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE 
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 
 * POSSIBILITY OF SUCH DAMAGE.
 *
 * $Id: RestService.java 2713 2006-06-01 14:38:45Z jmettraux $
 */

//
// RestService.java
//
// john.mettraux@openwfe.org
//
// generated with 
// jtmpl 1.0.04 31.10.2002 John Mettraux (jmettraux@openwfe.org)
//

package openwfe.org.rest;

import java.nio.channels.SelectionKey;
import java.nio.channels.SocketChannel;
import java.nio.channels.ReadableByteChannel;

import openwfe.org.Utils;
import openwfe.org.MapUtils;
import openwfe.org.Application;
import openwfe.org.ServiceException;
import openwfe.org.ApplicationContext;
import openwfe.org.xml.XmlUtils;
import openwfe.org.net.NetUtils;
import openwfe.org.net.SocketService;
import openwfe.org.net.ChannelInputStream;
import openwfe.org.time.Time;
import openwfe.org.misc.IoUtils;


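/**
 * HTTP listening service (London Calling?).
 *
 * <p>A request whose request line carries no session id is treated as a
 * login: the client must supply HTTP BASIC credentials, which are used to
 * initialize a new RestSession (of the class named by the
 * 'restSessionClass' parameter). If that succeeds, the new session id is
 * returned to the client as a &lt;session id="..."/&gt; element. Later
 * requests carrying that id in their request line are handed over to the
 * matching RestSession. Sessions idle for longer than 'sessionTimeout' are
 * purged every 'purgeFrequency'.
 *
 * <p><font size=2>CVS Info :
 * <br>$Author: jmettraux $
 * <br>$Date: 2006-06-01 16:38:45 +0200 (Thu, 01 Jun 2006) $
 * <br>$Id: RestService.java 2713 2006-06-01 14:38:45Z jmettraux $ </font>
 *
 * @author john.mettraux@openwfe.org
 */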
public class RestService

    extends SocketService

{

    private final static org.apache.log4j.Logger log = org.apache.log4j.Logger
        .getLogger(RestService.class.getName());

    //
    // CONSTANTS & co

    public final static int DEFAULT_PORT 
        = 6080;

    /** 
     * Use this parameter name ('serverName') to tell the service what
     * is the server name to be advertised to the client.
     */
    public final static String P_SERVER_NAME
        = "serverName";

    /**
     * Use this parameter ('restSessionClass') to tell the service which
     * RestSession implementation it should use.
     */
    public final static String P_REST_SESSION_CLASS
        = "restSessionClass";

    /**
     * Use this parameter name ('purgeFrequency') to tell the RestService
     * how frequently it should check for timed out sessions.
     */
    public final static String P_PURGE_FREQUENCY
        = "purgeFrequency";

    /**
     * Use this parameter name ('sessionTimeout') to tell the RestService
     * what is the max idle time for a REST session.
     */
    public final static String P_SESSION_TIMEOUT
        = "sessionTimeout";

    protected final static String DEFAULT_SERVER_NAME
        = "$Id: RestService.java 2713 2006-06-01 14:38:45Z jmettraux $";

    protected final static long DEFAULT_PURGE_FREQUENCY
        = Time.parseTimeString("2m");

    protected final static long DEFAULT_SESSION_TIMEOUT
        = Time.parseTimeString("10m");


    //private final static String K_HEADER_LENGTH
    //    = "__header_length";

    private final static String K_CONTENT_LENGTH
        = "Content-length";

    //
    // FIELDS

    private java.util.Map unmodifiableServiceParams = null;

    private String serverName = null;

    private String restSessionClassName = null;

    private java.util.Map sessions = new java.util.HashMap();

    private java.util.TimerTask timeOutTask = null;

    private Long lastGivenId = new Long(-1);

    //
    // CONSTRUCTORS

    public void init 
        (final String serviceName, 
         final ApplicationContext context, 
         final java.util.Map serviceParams)
    throws 
        ServiceException
    {
        setDefaultPort(DEFAULT_PORT); // else it will listen on 7000...

        super.init(serviceName, context, serviceParams);

        this.unmodifiableServiceParams = java.util.Collections
            .unmodifiableMap(serviceParams);

        //
        // determining serverName
        
        this.serverName = MapUtils.getAsString
            (serviceParams, P_SERVER_NAME, DEFAULT_SERVER_NAME);

        //
        // determining RestSession implementation to use
        
        this.restSessionClassName = MapUtils.getAsString
            (serviceParams, P_REST_SESSION_CLASS);

        try
        {
            final Class c = Class.forName(this.restSessionClassName);
            final RestSession tmpSession = (RestSession)c.newInstance();

            //
            // the tmpSession establishment is performed once at initialization
            // of the RestService. If succesfull, it proves that the 
            // wrapped around RestSession is operational and thus that the
            // RestService is operational.
            //
        }
        catch (final Throwable t)
        {
            throw new ServiceException
                ("Cannot use RestSession of class "+this.restSessionClassName, 
                 t);
        }

        //
        // starting session timeout system
        
        long timeout = MapUtils.getAsLong
            (serviceParams, P_SESSION_TIMEOUT, DEFAULT_SESSION_TIMEOUT);
        long frequency = MapUtils.getAsLong
            (serviceParams, P_PURGE_FREQUENCY, DEFAULT_PURGE_FREQUENCY);

        log.info("Session timeout set to "+timeout+" ms");
        log.info("Purge frequency set to "+frequency+" ms");

        final long finalTimeout = timeout;

        this.timeOutTask = new java.util.TimerTask()
        {
            public void run ()
            {
                timeOutSessions(finalTimeout);
            }
        };

        Application.getTimer().schedule(this.timeOutTask, 15, frequency);

        //
        // done

        log.info("Service '"+getName()+"' ready.");
    }

    //
    // METHODS

    public String getServerName ()
    {
        return this.serverName;
    }

    public String getRestSessionClassName ()
    {
        return this.restSessionClassName;
    }

    //
    // METHODS from Service

    public org.jdom.Element getStatus ()
    {
        org.jdom.Element result = new org.jdom.Element(getName());

        result.addContent(XmlUtils.getClassElt(this));
        result.addContent(XmlUtils.getRevisionElt("$Id: RestService.java 2713 2006-06-01 14:38:45Z jmettraux $"));

        return result;
    }

    //
    // METHODS from SocketService

    private String[] extractHeaders (final SocketChannel channel)
        throws java.io.IOException
    {
        final java.util.List l = new java.util.ArrayList(10);
        
        StringBuffer line = new StringBuffer();

        //int count = 0;
        int zeroReads = 0;
        while (true)
        {
            final int i = IoUtils.read(channel);

            if (i == -1) break;

            if (i == 0)
            {
                if (zeroReads >= 3)
                {
                    l.add(line.toString());
                    //log.debug("extractHeaders() zero-final line >"+line+"<");
                    break;
                }

                zeroReads++;

                Thread.yield();

                continue;
            }

            //count++;

            zeroReads = 0;

            final char c = (char)i;

            // debugging
            //String sc = ""+c;
            //if (c == '\r') sc="\\r";
            //if (c == '\n') sc="\\n";
            //log.debug("extractHeaders() read |"+sc+"| ("+i+")");

            if (c == '\r')
            {
                continue;
            }

            if (c == '\n')
            {
                String sLine = line.toString().trim();

                if (sLine.length() < 1) 
                {
                    //log.debug("extractHeaders() found empty line. EOH.");
                    break;
                }

                l.add(sLine);

                //log.debug("extractHeaders() added line >"+sLine+"<");

                line = new StringBuffer();

                continue;
            }

            line.append(c);
        }

        //l.add(K_HEADER_LENGTH+": "+count);

        return Utils.toStringArray(l);
    }

    /**
     * Handles a client connection.
     */
    public void handle (final SelectionKey key)
        throws ServiceException
    {
        try
        {
            final SocketChannel channel = (SocketChannel)key.channel();

            if (log.isDebugEnabled())
            {
                log.debug
                    ("handle() incoming connection from "+
                     channel.socket().getInetAddress());
            }

            final String[] headers = extractHeaders(channel);
            final String firstLine = headers[0];

            if ( ! verify(key, firstLine)) return;

            final Long sessionId = extractSessionId(firstLine);

            if (sessionId == null)
                //
                // authenticate
            {
                authenticate(key, headers);
                return;
            }

            //
            // is the session registered ?
        
            final RestSession session = (RestSession)this.sessions
                .get(sessionId);

            //log.debug
            //    ("Sessions : "+
            //     RestSession.printSessions(this.sessions));

            if (session == null)
                //
                // no
            {
                NetUtils.httpReply
                    (key,
                     404,
                     "No such session",
                     this.serverName,
                     null,
                     "text/plain",
                     null);
                return;
            }

            //
            // yes

            session.handle(key, headers);
        }
        catch (final Throwable t)
        {
            throw new ServiceException
                ("Socket handling failed", t);
        }
    }

    public void stop ()
        throws ServiceException
    {
        this.timeOutTask.cancel();
        log.info("timeout system for rest sessions stopped.");

        super.stop();
    }

    //
    // METHODS

    private void authenticate 
        (final SelectionKey key, final String[] headers)
    {

        Throwable throwable = null;
            // perhaps some exception will be lodged here. Hope none...

        //
        // is it an authorization request
        // or whatever else ?
        
        String sAuth = extractHeaderValue(headers, "Authorization");

        if (sAuth != null)
        {
            if ( ! sAuth.toLowerCase().startsWith("basic"))
            {
                NetUtils.httpReply
                    (key,
                     401,
                     "Unauthorized: only 'BASIC' authentication supported",
                     this.serverName,
                     null,
                     "text/plain",
                     null);
                return;
            }

            int i = sAuth.indexOf(" ");
            sAuth = sAuth.substring(i+1);

            //
            // base64 decoding (courtesy of Apache http://www.apache.org)

            //log.debug("*** comment me ! *** sAuth is >"+sAuth+"<");

            sAuth = 
                new String(openwfe.org.misc.Base64.decode(sAuth.getBytes()));

            //log.debug("*** comment me ! *** sAuth is >"+sAuth+"<");

            //
            // username:password

            String[] ss = sAuth.split(":");

            //
            // set up work session

            try
            {
                final Long sessionId = generateNewSessionId();

                if (log.isDebugEnabled())
                    log.debug("authenticate() new sessionId : "+sessionId);

                RestSession restSession = 
                    newRestSession(sessionId, ss[0], ss[1]);
                ss = null;

                this.sessions.put(sessionId, restSession);

                //
                // reply with session id

                org.jdom.Element eSession = new org.jdom.Element("session");
                eSession.setAttribute("id", ""+sessionId);

                log.debug("authenticate() 200 Authorized");

                NetUtils.httpReply
                    (key,
                     200,
                     "OK",
                     this.serverName,
                     null,
                     "text/xml",
                     eSession);
                return;
            }
            catch (final Throwable t)
            {
                throwable = t;

                log.debug("authenticate() failed.", t);

                // fall to "401 Unauthorized"
            }
        }

        log.debug("authenticate() 401 Unauthorized");

        NetUtils.httpReply
            (key,
             401,
             "Unauthorized",
             this.serverName,
             new String[] 
             { 
                 "WWW-Authenticate: BASIC realm=\""+this.serverName+"\"" 
             },
             "text/plain",
             throwable);
    }

    protected RestSession newRestSession
        (final Long sessionId, final String username, final String password)
    throws
        Exception
    {
        final Class clazz = Class.forName(this.restSessionClassName);
        final RestSession session = (RestSession)clazz.newInstance();

        session.init
            (this,
             sessionId, 
             username, 
             password);

        log.debug("newRestSession() ok.");

        return session;
    }

    private Long generateNewSessionId ()
    {
        synchronized (this.lastGivenId)
        {
            long id = System.currentTimeMillis();

            while (id <= this.lastGivenId.longValue()) id++;

            this.lastGivenId = new Long(id);

            return this.lastGivenId;
        }
    }

    private void timeOutSessions (long maxDelta)
    {
        log.debug("timeOutSessions() session cleaner waking up");

        long now = System.currentTimeMillis();

        synchronized (this.sessions)
        {
            java.util.List sessionsToRemove = 
                new java.util.ArrayList(this.sessions.size());
            
            java.util.Iterator it = this.sessions.keySet().iterator();
            while (it.hasNext())
            {
                Long sessionId = (Long)it.next();

                RestSession session = 
                    (RestSession)this.sessions.get(sessionId);

                final long delta = now - session.getLastUsed();

                if (delta > maxDelta) 
                {
                    if (log.isDebugEnabled())
                    {
                        log.debug
                            ("timeOutSessions() Removing session "+sessionId);
                    }

                    sessionsToRemove.add(sessionId);
                }
            }

            //
            // removing

            it = sessionsToRemove.iterator();
            while (it.hasNext()) this.sessions.remove(it.next());
        }
    }

    protected void removeSession (Long sessionId)
    {
        synchronized (this.sessions)
        {
            this.sessions.remove(sessionId);
        }
    }

    //
    // STATIC METHODS

    /**
     * Extracts the session id out of the first request line.
     */
    protected static Long extractSessionId (final String firstLine)
        throws java.io.IOException
    {
        try
        {
            final String value = RestUtils.extractFromLine
                (firstLine, "session");

            if (log.isDebugEnabled())
                log.debug("extractSessionId() sessionId = \""+value+"\"");

            return new Long(Long.parseLong(value));
        }
        catch (Exception e)
        {
            //log.debug("extractSessionId() Failed to extract sessionId", e);

            return null;
        }
    }

    /**
     * Given a string array of HTTP (post) headers, returns the value
     * for a given key.
     */
    public static String extractHeaderValue 
        (final String[] headers, final String key)
    {
        //log.debug("extractHeaderValue() looking for '"+key+"'");

        final String sKey = key.toLowerCase() + ": ";

        for (int i=0; i<headers.length; i++)
        {
            String line = headers[i];

            //log.debug("extractHeaderValue() examinining >"+line+"<");

            if (line.toLowerCase().startsWith(sKey))
                return line.substring(key.length()+2);
        }

        if (log.isDebugEnabled())
        {
            log.debug
                ("extractHeaderValue() didn't find value for key '"+key+"'");
        }

        return null;
    }

    private static String print (final String[] ss)
    {
        StringBuffer sb = new StringBuffer();

        sb.append("[");
        for (int i=0; i<ss.length; i++)
        {
            sb.append("'");
            sb.append(ss[i]);
            sb.append("'");

            if (i < ss.length-1) sb.append(", ");
        }
        sb.append("]");

        return sb.toString();
    }

    /*
     * verify first line
     */
    private boolean verify 
        (final SelectionKey key, final String firstLine)
    {
        final String[] ss = firstLine.split(" ");

        if (log.isDebugEnabled())
            log.debug("verify() ss is "+print(ss));

        if (ss.length < 3)
        {
            NetUtils.httpReply
                (key,
                 400,
                 "Bad request",
                 this.serverName,
                 null,
                 "text/plain",
                 "Incomplete HTTP request");
            return false;
        }

        //
        // check method

        String method = ss[0].toUpperCase();
        if ( ! method.equals("GET") && ! method.equals("POST"))
        {
            NetUtils.httpReply
                (key,
                 405,
                 "Method Not Allowed",
                 this.serverName,
                 null,
                 "text/plain",
                 "Method '"+method+"' not allowed.");
            return false;
        }

        //
        // check URI
        
        if ( ! ss[1].startsWith("/"+this.serverName))
        {
            NetUtils.httpReply
                (key,
                 404,
                 "Not Found",
                 this.serverName,
                 null,
                 "text/plain",
                 "Object >"+ss[1]+"< not found.");
            return false;
        }

        //
        // check protocol
        
        if ( ! ss[2].startsWith("HTTP/"))
        {
            NetUtils.httpReply
                (key,
                 505,
                 "HTTP Version not supported",
                 this.serverName,
                 null,
                 "text/plain",
                 "HTTP Version >"+ss[2]+"< not supported.");
            return false;
        }

        //
        // any other check ?

        return true;
    }

    /**
     * Extract a numeric (int) value from the HTTP headers.
     */
    public static int extractNumericHeaderValue 
        (final String[] headers, final String key)
    {
        final String s = extractHeaderValue(headers, key);

        try
        {
            return Integer.parseInt(s);
        }
        catch (final Throwable t)
        {
            // ignore
        }

        return -1;
    }

    /**
     * Determines how many bytes are still to be read.
     */
    public static int determineBytesToRead (final String[] headers)
    {
        /*
        final int headerLength = 
            extractNumericHeaderValue(headers, K_HEADER_LENGTH);
        final int contentLength =
            extractNumericHeaderValue(headers, K_CONTENT_LENGTH);

        log.debug
            ("determineBytesToRead() hl : "+headerLength+
             "  cl : "+contentLength+
             "     delta = "+(contentLength - headerLength));

        return contentLength - headerLength;
        */

        return extractNumericHeaderValue(headers, K_CONTENT_LENGTH);
    }

}
