/**
 * CMI : Cluster Method Invocation
 * Copyright (C) 2007 Bull S.A.S.
 * Contact: carol@ow2.org
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public
 * License as published by the Free Software Foundation; either
 * version 2.1 of the License, or (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with this library; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA
 * --------------------------------------------------------------------------
 * $Id: SmartClassLoader.java 1547 2007-12-13 21:32:55Z loris $
 * --------------------------------------------------------------------------
 */

package org.ow2.carol.cmi.smart.client;

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.URLClassLoader;
import java.net.URL;
import java.nio.ByteBuffer;
import java.nio.channels.SocketChannel;
import java.util.List;
import java.util.logging.Level;
import java.util.logging.Logger;

import org.ow2.carol.cmi.smart.api.Message;
import org.ow2.carol.cmi.smart.api.Operations;
import org.ow2.carol.cmi.smart.message.ClassAnswer;
import org.ow2.carol.cmi.smart.message.ClassRequest;
import org.ow2.carol.cmi.smart.message.FactoryNameRequest;
import org.ow2.carol.cmi.smart.message.FactoryNameAnswer;
import org.ow2.carol.cmi.smart.message.ProviderURLsAnswer;
import org.ow2.carol.cmi.smart.message.ProviderURLsRequest;
import org.ow2.carol.cmi.smart.message.ResourceAnswer;
import org.ow2.carol.cmi.smart.message.ResourceRequest;

/**
 * Class loader that downloads classes and resources from a remote CMI smart
 * server.
 * <p>
 * Every request opens a new {@link SocketChannel} to the configured
 * host/port, writes one request message, reads one answer and closes the
 * channel. Downloaded resources are written below
 * {@code ${java.io.tmpdir}/cmi-smart} so that they can be returned as
 * {@code file:} URLs.
 * <p>
 * Download statistics are shared by all instances and printed once when the
 * JVM shuts down.
 *
 * @author The new CMI team
 * @author Florent Benoit
 */
public class SmartClassLoader extends URLClassLoader {

    /**
     * Use the JDK logger (to avoid any dependency).
     */
    private static final Logger logger = Logger.getLogger(SmartClassLoader.class
            .getName());

    /**
     * Size of the buffer used for each read on the socket.
     */
    private static final int DEFAULT_BUFFER_SIZE = 3000;

    /**
     * Whether the statistics shutdown hook has already been registered.
     * Guarded by the {@code SmartClassLoader.class} monitor.
     */
    private static boolean shutdownHookRegistered = false;

    /**
     * Number of classes downloaded.
     */
    private static int nbClasses = 0;

    /**
     * Number of resources downloaded.
     */
    private static int nbResources = 0;

    /**
     * Number of bytes downloaded.
     */
    private static long nbBytes = 0;

    /**
     * All the time (in ms) spent asking the server.
     */
    private static long timeToDownload = 0;

    /**
     * Socket address of the remote server.
     */
    private final InetSocketAddress socketAddress;

    /**
     * Creates a new classloader whose parent is the current thread context
     * classloader.
     *
     * @param host
     *            The given host.
     * @param port
     *            The given port number.
     */
    public SmartClassLoader(final String host, final int port) {
        super(new URL[0], Thread.currentThread().getContextClassLoader());
        socketAddress = new InetSocketAddress(host, port);

        // Statistics are static, so a single hook is enough no matter how
        // many class loaders are created.
        registerShutdownHook();
    }

    /**
     * Registers the statistics shutdown hook the first time it is called;
     * later calls do nothing.
     */
    private static synchronized void registerShutdownHook() {
        if (!shutdownHookRegistered) {
            Runtime.getRuntime().addShutdownHook(new ShutdownHook());
            shutdownHookRegistered = true;
        }
    }

    /**
     * Get the list of available provider urls with the given protocol.
     *
     * @param protocol
     *            The given protocol used in the client side.
     * @return The list of provider urls, or <code>null</code> if
     *         <code>protocol</code> is <code>null</code>.
     * @throws IllegalStateException
     *             if the exchange with the server fails or an unexpected
     *             answer is received.
     */
    public List<String> getProviderURLs(final String protocol) {
        if (protocol == null) {
            logger.log(Level.INFO, "The protocol of client is missing!");
            return null;
        }

        List<String> providerURLs = null;
        SocketChannel channel = null;
        try {
            long tStart = System.currentTimeMillis();
            channel = getChannel();
            ByteBuffer answerBuffer = sendRequest(new ProviderURLsRequest(
                    protocol), channel);
            byte opCode = getOpCode(answerBuffer, channel);

            // stats
            timeToDownload = timeToDownload
                    + (System.currentTimeMillis() - tStart);

            switch (opCode) {
            case Operations.PROVIDER_URLS_ANSWER:
                providerURLs = new ProviderURLsAnswer(answerBuffer)
                        .getProviderURLs();
                break;
            default:
                throw new IllegalStateException("Invalid opCode '" + opCode
                        + "' received");
            }
        } finally {
            cleanChannel(channel);
        }
        return providerURLs;
    }

    /**
     * Get the factory name with the given protocol.
     *
     * @param protocol
     *            The protocol used in the client side.
     * @return The factory name, or <code>null</code> if <code>protocol</code>
     *         is <code>null</code>.
     * @throws IllegalStateException
     *             if the exchange with the server fails or an unexpected
     *             answer is received.
     */
    public String getWrappedFactoryName(final String protocol) {
        if (protocol == null) {
            logger.log(Level.INFO, "The protocol of client is missing!");
            return null;
        }

        String factoryName = null;
        SocketChannel channel = null;
        try {
            long tStart = System.currentTimeMillis();
            channel = getChannel();
            ByteBuffer answerBuffer = sendRequest(new FactoryNameRequest(
                    protocol), channel);
            byte opCode = getOpCode(answerBuffer, channel);

            // stats (same accounting as the other requests)
            timeToDownload = timeToDownload
                    + (System.currentTimeMillis() - tStart);

            switch (opCode) {
            case Operations.FACTORY_NAME_ANSWER:
                factoryName = new FactoryNameAnswer(answerBuffer)
                        .getFactoryName();
                break;
            default:
                throw new IllegalStateException("Invalid opCode '" + opCode
                        + "' received");
            }
        } finally {
            cleanChannel(channel);
        }
        return factoryName;
    }

    /**
     * Opens a (blocking) channel connected to the server.
     *
     * @return a connected socket channel.
     * @throws IllegalStateException
     *             if the channel cannot be opened or connected.
     */
    private SocketChannel getChannel() {
        SocketChannel channel;
        try {
            channel = SocketChannel.open();
        } catch (IOException e) {
            throw new IllegalStateException("Cannot open a channel", e);
        }

        // Close the channel on ANY connect failure, including runtime ones
        // such as UnresolvedAddressException, so it never leaks.
        boolean connected = false;
        try {
            channel.connect(socketAddress);
            connected = true;
        } catch (IOException e) {
            throw new IllegalStateException("Cannot connect the channel", e);
        } finally {
            if (!connected) {
                cleanChannel(channel);
            }
        }
        return channel;
    }

    /**
     * Closes the given channel, ignoring (but logging at FINE) close
     * failures. Does nothing if the channel is <code>null</code>.
     *
     * @param channel
     *            the channel to cleanup.
     */
    private void cleanChannel(final SocketChannel channel) {
        if (channel != null) {
            try {
                channel.close();
            } catch (IOException e) {
                logger.log(Level.FINE, "Cannot close the given channel", e);
            }
        }
    }

    /**
     * Sends the given message on the given channel and reads the complete
     * answer.
     * <p>
     * The answer header is a one-byte operation code followed by an
     * <code>int</code> payload length. The returned buffer holds the header
     * and the payload; its position is set just after the header and its
     * limit at the end of the payload.
     *
     * @param message
     *            the message to send
     * @param channel
     *            the channel used to send the message.
     * @return the bytebuffer containing the answer (to analyze)
     * @throws IllegalStateException
     *             if the message cannot be sent, or the complete answer
     *             cannot be read (in which case the channel is closed).
     */
    public ByteBuffer sendRequest(final Message message,
            final SocketChannel channel) {

        // Send request
        try {
            channel.write(message.getMessage());
        } catch (IOException e) {
            cleanChannel(channel);
            throw new IllegalStateException("Cannot send the given message '"
                    + message + "'.", e);
        }

        // Read response
        ByteBuffer buffer = ByteBuffer.allocateDirect(DEFAULT_BUFFER_SIZE);
        ByteBuffer completeBuffer = null;
        boolean finished = false;

        try {
            int length = 0;
            while (!finished && channel.read(buffer) != -1) {
                if (completeBuffer == null) {
                    if (buffer.position() < Message.HEADER_SIZE) {
                        // Header not fully received yet: keep accumulating
                        // into 'buffer' (it is not cleared) and read again.
                        continue;
                    }
                    length = buffer.getInt(1);
                    if (length < 0) {
                        throw new IllegalStateException(
                                "Invalid length for answer '" + length + "'.");
                    }
                    // Extra DEFAULT_BUFFER_SIZE bytes so that the last chunk
                    // (at most DEFAULT_BUFFER_SIZE bytes) always fits.
                    completeBuffer = ByteBuffer.allocate(Message.HEADER_SIZE
                            + length + DEFAULT_BUFFER_SIZE);
                }

                // Append all read data into completeBuffer
                buffer.flip();
                completeBuffer.put(buffer);
                // clear for next time
                buffer.clear();

                if (completeBuffer.position() >= Message.HEADER_SIZE + length) {
                    completeBuffer.limit(Message.HEADER_SIZE + length);
                    // Skip Header, got OpCode, now create function
                    completeBuffer.position(Message.HEADER_SIZE);
                    finished = true;
                }
            }
        } catch (Exception e) {
            cleanChannel(channel);
            throw new IllegalStateException(
                    "Cannot read the answer from the server.", e);
        }

        if (!finished) {
            // The server closed the connection before the whole answer was
            // received: never hand back a null or truncated buffer.
            cleanChannel(channel);
            throw new IllegalStateException(
                    "Cannot read the answer from the server: connection closed"
                            + " before the complete answer was received.");
        }

        return completeBuffer;
    }

    /**
     * Gets the operation code from the answer header (byte 0) and validates
     * the payload length stored right after it.
     *
     * @param buffer
     *            the buffer to analyze.
     * @param channel
     *            the channel which is use for the exchange (closed if the
     *            length is invalid).
     * @return the operation code.
     * @throws IllegalStateException
     *             if the header length is negative.
     */
    private byte getOpCode(final ByteBuffer buffer, final SocketChannel channel) {
        // Get operation asked by client
        byte opCode = buffer.get(0);
        // Length
        int length = buffer.getInt(1);
        if (length < 0) {
            cleanChannel(channel);
            throw new IllegalStateException("Invalid length for client '"
                    + length + "'.");
        }
        return opCode;
    }

    /**
     * Asks the remote server for the bytecode of the given class and defines
     * it in this loader.
     * <p>
     * The URL search path of this loader is not consulted: every call goes
     * to the remote server. (Parent delegation is performed beforehand by
     * {@link ClassLoader#loadClass(String)}.)
     *
     * @param name
     *            the name of the class
     * @return the resulting class
     * @exception ClassNotFoundException
     *                if the server does not know the class, answers with an
     *                unexpected operation code, or the bytecode cannot be
     *                defined.
     */
    @Override
    protected synchronized Class<?> findClass(final String name)
            throws ClassNotFoundException {
        logger.log(Level.INFO, "Try to find the class \"" + name + "\"...");

        Class<?> clazz = null;
        SocketChannel channel = null;
        try {
            long tStart = System.currentTimeMillis();
            channel = getChannel();
            ByteBuffer answerBuffer = sendRequest(new ClassRequest(name),
                    channel);
            byte opCode = getOpCode(answerBuffer, channel);

            // stats
            timeToDownload = timeToDownload
                    + (System.currentTimeMillis() - tStart);

            switch (opCode) {
            case Operations.CLASS_ANSWER:
                byte[] byteCode = new ClassAnswer(answerBuffer).getByteCode();
                try {
                    clazz = defineRemoteClass(name, byteCode);
                } catch (IOException e) {
                    throw new ClassNotFoundException(
                            "Cannot find the class", e);
                }
                nbClasses++;
                nbBytes = nbBytes + byteCode.length;
                break;
            case Operations.CLASS_NOT_FOUND:
                throw new ClassNotFoundException("The class '" + name
                        + "' was not found on the remote side");
            default:
                throw new ClassNotFoundException("Invalid opCode '"
                        + opCode + "' received");
            }
        } finally {
            cleanChannel(channel);
        }
        return clazz;
    }

    /**
     * Defines a class in this loader from the given bytecode.
     * <p>
     * Calls {@link ClassLoader#defineClass(String, byte[], int, int)}
     * directly (this class is a <code>ClassLoader</code>), so no reflective
     * access to <code>java.lang</code> internals is needed.
     *
     * @param className
     *            the name of the class to define
     * @param bytecode
     *            the bytecode of the class
     * @return the class that was defined
     * @throws IOException
     *             if the class cannot be defined (the original error is
     *             kept as the cause).
     */
    private Class<?> defineRemoteClass(final String className,
            final byte[] bytecode) throws IOException {
        try {
            return defineClass(className, bytecode, 0, bytecode.length);
        } catch (LinkageError e) {
            // ClassFormatError, NoClassDefFoundError (name mismatch), ...
            throw newDefineFailure(className, e);
        } catch (RuntimeException e) {
            // SecurityException (prohibited package), NPE on null bytecode
            throw newDefineFailure(className, e);
        }
    }

    /**
     * Builds the exception reported when a class cannot be defined.
     *
     * @param className
     *            the name of the class that could not be defined
     * @param cause
     *            the underlying failure
     * @return an IOException carrying <code>cause</code>
     */
    private static IOException newDefineFailure(final String className,
            final Throwable cause) {
        IOException ioe = new IOException("Cannot define class with name '"
                + className + "'.");
        ioe.initCause(cause);
        return ioe;
    }

    /**
     * Asks the remote server for the given resource and returns a
     * <code>file:</code> URL to a local copy written below
     * <code>${java.io.tmpdir}/cmi-smart</code>.
     * <p>
     * The URL search path of this loader is not consulted, and names
     * starting with <code>META-INF</code> are never requested (this method
     * returns <code>null</code> for them). Any failure is logged at SEVERE
     * and results in <code>null</code>.
     *
     * @param name
     *            the name of the resource
     * @return a <code>URL</code> for the resource, or <code>null</code> if
     *         the resource could not be found.
     */
    @Override
    public synchronized URL findResource(final String name) {
        if (name.startsWith("META-INF")) {
            return null;
        }

        URL url = null;
        SocketChannel channel = null;
        try {
            long tStart = System.currentTimeMillis();
            channel = getChannel();
            ByteBuffer answerBuffer = sendRequest(new ResourceRequest(name),
                    channel);
            byte opCode = getOpCode(answerBuffer, channel);

            // stats
            timeToDownload = timeToDownload
                    + (System.currentTimeMillis() - tStart);

            switch (opCode) {
            case Operations.RESOURCE_ANSWER:
                ResourceAnswer resourceAnswer = new ResourceAnswer(answerBuffer);
                byte[] bytes = resourceAnswer.getBytes();

                nbResources++;
                nbBytes = nbBytes + bytes.length;

                url = dumpResource(resourceAnswer.getResourceName(), bytes);
                break;
            case Operations.RESOURCE_NOT_FOUND:
                url = null;
                break;
            default:
                throw new IllegalStateException("Invalid opCode '" + opCode
                        + "' received");
            }
        } catch (Exception e) {
            logger.log(Level.SEVERE, "Cannot handle : findResource '" + name
                    + "'", e);
        } finally {
            cleanChannel(channel);
        }
        return url;
    }

    /**
     * Writes a downloaded resource below
     * <code>${java.io.tmpdir}/cmi-smart</code> and returns its URL.
     *
     * @param resourceName
     *            the '/'-separated resource name sent by the server
     * @param bytes
     *            the resource content
     * @return the <code>file:</code> URL of the written copy
     * @throws IOException
     *             if the file cannot be written, or the name would resolve
     *             outside of the cache directory.
     */
    private static URL dumpResource(final String resourceName,
            final byte[] bytes) throws IOException {
        File fConfDir = new File(System.getProperty("java.io.tmpdir")
                + File.separator + "cmi-smart");
        if (!fConfDir.exists()) {
            fConfDir.mkdirs();
        }

        // convert / into File.separator
        StringBuilder sb = new StringBuilder();
        for (String token : resourceName.split("/")) {
            if (sb.length() > 0) {
                sb.append(File.separator);
            }
            sb.append(token);
        }
        File urlFile = new File(fConfDir, sb.toString());

        // The name comes from the network: refuse anything (e.g. "../")
        // that would make us write outside of the cache directory.
        String rootPath = fConfDir.getCanonicalPath() + File.separator;
        if (!urlFile.getCanonicalPath().startsWith(rootPath)) {
            throw new IOException("Resource name '" + resourceName
                    + "' resolves outside of '" + fConfDir + "'.");
        }

        // Create all missing parent directories (resource names may be nested)
        File parentDir = urlFile.getParentFile();
        if (!parentDir.exists()) {
            parentDir.mkdirs();
        }

        // dump stream, always releasing the file handle
        FileOutputStream fos = new FileOutputStream(urlFile);
        try {
            fos.write(bytes);
        } finally {
            fos.close();
        }
        return urlFile.toURI().toURL();
    }

    /**
     * Hook that is called when process is going to shutdown.
     *
     * @author Florent Benoit
     */
    static class ShutdownHook extends Thread {

        /**
         * Display stats.
         */
        @Override
        public void run() {
            // display statistics (use sysout)
            System.out.println("Downloaded '" + nbClasses + "' classes, '"
                    + nbResources + "' resources for a total of '" + nbBytes
                    + "' bytes and it took '" + timeToDownload + "' ms.");
        }
    }
}
