/* -*- mode: Java; c-basic-offset: 4; indent-tabs-mode: nil; -*-  //------100-columns-wide------>|*/
/* Copyright (c) 2005 Extreme! Lab, Indiana University. All rights reserved.
 * This software is open source. See the bottom of this file for the license.
 * $Id: Transport.java,v 1.23 2007/03/14 21:06:25 aslom Exp $ */
package org.gpel.client.http.apache_http_client;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.UnsupportedEncodingException;
import java.net.InetAddress;
import java.net.Socket;
import java.net.URI;
import java.net.UnknownHostException;
import java.security.InvalidKeyException;
import java.security.KeyManagementException;
import java.security.NoSuchAlgorithmException;
import java.security.NoSuchProviderException;
import java.security.Principal;
import java.security.PrivateKey;
import java.security.SignatureException;
import java.security.cert.CertificateException;
import java.security.cert.CertificateExpiredException;
import java.security.cert.X509Certificate;
import java.util.Collection;
import java.util.List;
import javax.net.ssl.KeyManager;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.TrustManager;
import javax.net.ssl.X509ExtendedKeyManager;
import javax.net.ssl.X509TrustManager;
import org.apache.commons.httpclient.ConnectTimeoutException;
import org.apache.commons.httpclient.Credentials;
import org.apache.commons.httpclient.HostConfiguration;
import org.apache.commons.httpclient.HttpClient;
import org.apache.commons.httpclient.HttpConnection;
import org.apache.commons.httpclient.HttpConnectionManager;
import org.apache.commons.httpclient.HttpMethod;
import org.apache.commons.httpclient.HttpStatus;
import org.apache.commons.httpclient.SimpleHttpConnectionManager;
import org.apache.commons.httpclient.UsernamePasswordCredentials;
import org.apache.commons.httpclient.auth.AuthScope;
import org.apache.commons.httpclient.methods.EntityEnclosingMethod;
import org.apache.commons.httpclient.methods.GetMethod;
import org.apache.commons.httpclient.methods.InputStreamRequestEntity;
import org.apache.commons.httpclient.methods.PostMethod;
import org.apache.commons.httpclient.methods.PutMethod;
import org.apache.commons.httpclient.params.HttpConnectionParams;
import org.apache.commons.httpclient.protocol.ControllerThreadSocketFactory;
import org.apache.commons.httpclient.protocol.Protocol;
import org.apache.commons.httpclient.protocol.SecureProtocolSocketFactory;
import org.gpel.GpelConstants;
import org.gpel.GpelVersion;
import org.gpel.client.GcException;
import org.gpel.client.GcResourceNotFoundException;
import org.gpel.client.GcUtil;
import org.gpel.client.GcWebResourceType;
import org.gpel.client.GpelUserCredentials;
import org.gpel.client.http.GcHttpException;
import org.gpel.client.http.GcHttpRequest;
import org.gpel.client.http.GcHttpResponse;
import org.gpel.client.http.GcHttpTransport;
//import org.gpel.client.security.GpelUserX509Credential;
import org.gpel.logger.GLogger;
import org.xmlpull.infoset.XmlElement;
import org.xmlpull.infoset.XmlInfosetBuilder;

/**
 * Implementation of GPEL Client transports that uses Jakarta Apache HTTP Client library
 * (commons-httpclient 3.x).
 * <p>
 * Requests to {@code https} locations use {@code secureClient} when one has been configured;
 * every other request uses {@code unsecureClient}. When the transport was created with
 * {@link GpelUserCredentials}, requests that ask for authorization get those username/password
 * credentials attached (see {@code requireCredentials}).
 */
public class Transport implements GcHttpTransport {
    private final static XmlInfosetBuilder builder = GpelConstants.BUILDER;

    private final static GLogger logger = GLogger.getLogger();

    /** Client used for https locations; stays null unless a secure transport was set up. */
    private HttpClient secureClient;

    /** Client used for every location not served by {@code secureClient}. */
    private HttpClient unsecureClient;

    /** Username/password attached to requests that ask for authorization; may be null. */
    private GpelUserCredentials userCredentials;

//    private GpelUserX509Credential x509credentials;

    /**
     * Creates a transport without user credentials that uses a default {@link HttpClient}.
     */
    public Transport() {
        unsecureClient = new HttpClient();
    }

    /**
     * Creates a transport that attaches the given username/password to requests that ask for
     * authorization.
     *
     * @param credentials user credentials; must not be null
     * @throws IllegalArgumentException if {@code credentials} is null
     */
    public Transport(GpelUserCredentials credentials) {
        this();
        if (credentials == null) throw new IllegalArgumentException("credentials must not be null");
        this.userCredentials = credentials;
    }

//    public Transport(GpelUserX509Credential x509credentials) {
//        if (x509credentials == null) throw new IllegalArgumentException();
//        this.x509credentials = x509credentials;
//
//        final PrivateKey privateKey = x509credentials.getUserPrivateKey();
//        final X509Certificate[] certChain = x509credentials.getUserCertChain();
//        KeyManager[] kms = null;
//        if(privateKey != null) {
//            kms = new KeyManager[] { new OneChainKeyManager(certChain, privateKey) };
//        }
//        final X509Certificate[] trustedCerts = x509credentials.getCertificatesTrustedByUser();
//        // create trust managers with trusted certificates ...
//        TrustManager[] trustCerts;
//        if(trustedCerts == null) {
//            trustCerts = new TrustManager[] { new TrustAllX509TrustManager() };
//        } else {
//            trustCerts = new TrustManager[] { new TrustedListX509TrustManager(trustedCerts) };
//        }
//        SecureProtocolSocketFactory socketFactory;
//        try {
//            socketFactory = new LimitedTrustSSLProtocolSocketFactory(kms, trustCerts);
//        } catch (KeyManagementException e) {
//            throw new GcHttpException("failed to create secure transport ", e);
//        } catch (NoSuchAlgorithmException e) {
//            throw new GcHttpException("failed to create secure transport ", e);
//        }
//        // Protocol.registerProtocol("https", new Protocol("https", socketFactory, 443));
//        // NOTE: install protocol that is local to this.client and do not affect other client
//        // this allows to have different X509 credentials in different gpel clients ...
//        final Protocol localSecureProtocol = new Protocol("https", socketFactory, 443);
//        HttpConnectionManager unsecureConnMgr = new SimpleHttpConnectionManager();
//        HttpConnectionManager secureConnMgr = new SimpleHttpConnectionManager() {
//            @Override
//            public HttpConnection getConnectionWithTimeout(
//                    final HostConfiguration hostConfiguration, final long timeout) {
//                HttpConnection conn = super.getConnectionWithTimeout(hostConfiguration, timeout);
//                conn.setProtocol(localSecureProtocol);
//                return conn;
//            }
//        };
//        this.secureClient = new HttpClient(secureConnMgr);
//        this.unsecureClient = new HttpClient(unsecureConnMgr);
//        this.secureClient.getParams().setParameter("http.socket.timeout", new Integer(4*1000));
//        this.secureClient.getParams().setParameter("http.protocol.content-charset", "UTF-8");
//        this.unsecureClient.getParams().setParameter("http.socket.timeout", new Integer(4*1000));
//        this.unsecureClient.getParams().setParameter("http.protocol.content-charset", "UTF-8");
//    }

    /**
     * Fetches and parses the XML document at {@code location}, attaching user credentials
     * (if this transport has any).
     *
     * @param location URI of the XML resource
     * @return the parsed root element
     * @see #getXml(URI, boolean)
     */
    public XmlElement getXml(URI location) throws GcException {
        return getXml(location, true);
    }

    /**
     * Selects the HTTP client for a location: the secure client for {@code https} URIs (scheme
     * compared case-insensitively, as URI schemes are) when one is configured, otherwise the
     * plain client.
     *
     * @param location target URI
     * @return client to execute requests for {@code location} with
     */
    public HttpClient getClientForLocation(URI location) throws GcHttpException {
        String scheme = location.getScheme();
        if (secureClient != null && "https".equalsIgnoreCase(scheme)) {
            return secureClient;
        }
        return unsecureClient;
    }

    /**
     * Performs an HTTP GET on {@code location} and parses the response body as XML.
     * <p>
     * The body is handed to the XML parser as raw bytes so that the document's own encoding
     * declaration is honoured instead of HttpClient's default (ISO-8859-1) charset.
     *
     * @param location URI of the XML resource
     * @param withBasicAuthz whether to attach the user's credentials to the request
     * @return the parsed root element
     * @throws GcResourceNotFoundException if the server answers 404
     * @throws GcHttpException if the server answers with another non-success status or sends
     *         no body
     * @throws GcException if the HTTP exchange fails with an I/O error
     */
    public XmlElement getXml(URI location, boolean withBasicAuthz) throws GcHttpException {
        HttpMethod get = new GetMethod(location.toString());
        get.setRequestHeader("User-Agent", GpelVersion.getUserAgent());
        try {
            HttpClient client = getClientForLocation(location);
            if (withBasicAuthz) {
                requireCredentials(client, get, location);
            }
            client.executeMethod(get);
            // do not try to parse error pages as XML
            ensureSuccess(get, location);
            byte[] responseBody = get.getResponseBody();
            if (responseBody == null) {
                throw new GcHttpException("HTTP response from " + location + " has no XML body");
            }
            return builder.parseFragmentFromInputStream(new ByteArrayInputStream(responseBody));
        } catch (IOException e) {
            throw new GcException("HTTP transport get XML failed for " + location + " (withBasicAuthz="
                    + withBasicAuthz + ")", e);
        } finally {
            get.releaseConnection();
        }
    }

    /**
     * Executes a GPEL HTTP request and converts the server answer into a {@link GcHttpResponse}.
     *
     * @param req request to perform
     * @return the response (final location and, for XML responses, the parsed body)
     * @throws GcHttpException if the request cannot be built, the server answers with a
     *         non-success status, or the HTTP exchange fails
     */
    public GcHttpResponse perform(GcHttpRequest req) throws GcHttpException {
        HttpMethod method;
        try {
            method = createHttpMethodFor(req);
        } catch (IllegalStateException e) {
            throw new GcHttpException("failed to create HTTP method for " + req, e);
        }
        try {
            HttpClient client = getClientForLocation(req.getLocation());
            if (req.useAuthz()) {
                requireCredentials(client, method, req.getLocation());
            }
            client.executeMethod(method);
            return extractResponse(req, method);
        } catch (IOException e) {
            throw new GcHttpException("HTTP transport failed accessing " + req.getLocation(), e);
        } finally {
            // createHttpMethodFor never returns null, so the method can always be released
            method.releaseConnection();
        }
    }

    /**
     * Builds the commons-httpclient method (GET, PUT or POST) for a request, including the
     * User-Agent header and, when the request has content, the request entity.
     *
     * @param req request to translate
     * @return a ready-to-execute HTTP method
     * @throws GcHttpException for an unknown method, content attached to a method that cannot
     *         carry a body (e.g. GET), or a UTF-8 encoding failure
     * @throws IllegalStateException if the request claims content but carries none
     */
    private HttpMethod createHttpMethodFor(GcHttpRequest req) throws GcHttpException {
        String location = req.getLocation().toString();
        GcHttpRequest.Method met = req.getMethod();
        HttpMethod method;
        if (GcHttpRequest.Method.GET.equals(met)) {
            method = new GetMethod(location);
        } else if (GcHttpRequest.Method.PUT.equals(met)) {
            method = new PutMethod(location);
        } else if (GcHttpRequest.Method.POST.equals(met)) {
            method = new PostMethod(location);
        } else {
            throw new GcHttpException("unknown method " + met);
        }

        method.setRequestHeader("User-Agent", GpelVersion.getUserAgent());

        if (req.hasContent()) {
            byte[] content = getContentBytes(req);
            // only PUT/POST can carry a body; a blind cast would fail with ClassCastException
            if (!(method instanceof EntityEnclosingMethod)) {
                throw new GcHttpException("HTTP method " + met + " can not send request content");
            }
            EntityEnclosingMethod emet = (EntityEnclosingMethod) method;
            emet.setRequestEntity(new InputStreamRequestEntity(
                    new ByteArrayInputStream(content), content.length, req.getContentType()));
        }

        return method;
    }

    /**
     * Returns the request body as bytes. The first content kind present wins: serialized XML
     * content (UTF-8), then text content (UTF-8), then raw binary content.
     *
     * @throws IllegalStateException if none of the content kinds is set
     */
    private static byte[] getContentBytes(GcHttpRequest req) throws GcHttpException {
        if (req.getXmlContent() != null) {
            return encodeUtf8(builder.serializeToString(req.getXmlContent()));
        } else if (req.getTextContent() != null) {
            return encodeUtf8(req.getTextContent());
        } else if (req.getBinaryContent() != null) {
            return req.getBinaryContent();
        }
        throw new IllegalStateException("content missing");
    }

    /** Encodes text as UTF-8 bytes. */
    private static byte[] encodeUtf8(String text) throws GcHttpException {
        try {
            return text.getBytes("UTF-8");
        } catch (UnsupportedEncodingException e) {
            // should not happen: every JVM is required to support UTF-8
            throw new GcHttpException("could not get text content as UTF8", e);
        }
    }

    /**
     * Registers the user's username/password with {@code client} for the host and port of
     * {@code locationUri} and turns on preemptive authentication. Does nothing when this
     * transport has no user credentials.
     * <p>
     * A URI without an explicit port yields -1, which {@link AuthScope} treats as "any port".
     * NOTE(review): both the preemptive flag and the credentials are stored on the client
     * itself, so they remain in effect for later requests made with the same client.
     */
    private void requireCredentials(HttpClient client, HttpMethod method, URI locationUri) {
        if (userCredentials != null) {
            client.getParams().setAuthenticationPreemptive(true);
            method.setDoAuthentication(true);
            Credentials credentials = new UsernamePasswordCredentials(userCredentials
                    .getUserName(), userCredentials.getUserPassword());
            client.getState()
                    .setCredentials(new AuthScope(locationUri.getHost(), locationUri.getPort()),
                            credentials);
        }
    }

    /**
     * Throws unless the executed method finished with 200 OK, 201 Created or 202 Accepted.
     * The response body, if any, is appended to the exception message.
     *
     * @param method executed HTTP method
     * @param location request URI, used in messages and in the not-found exception
     * @throws GcResourceNotFoundException for 404
     * @throws GcHttpException for any other non-success status
     * @throws IOException if the error body cannot be read
     */
    private static void ensureSuccess(HttpMethod method, URI location)
            throws GcHttpException, IOException {
        int statusCode = method.getStatusCode();
        if (statusCode == HttpStatus.SC_OK || statusCode == HttpStatus.SC_CREATED
                || statusCode == HttpStatus.SC_ACCEPTED) {
            return;
        }
        String responseBody = method.getResponseBodyAsString();
        String extraInfo = "";
        if (responseBody != null) {
            extraInfo = "\n" + responseBody;
        }
        if (statusCode == HttpStatus.SC_NOT_FOUND) {
            throw new GcResourceNotFoundException("could not get resource from "
                    + location + extraInfo, location);
        }
        throw new GcHttpException("HTTP failed with status " + statusCode + " ("
                + method.getStatusLine() + ")" + extraInfo);
    }

    /**
     * Converts the result of an executed method into a {@link GcHttpResponse}.
     * <p>
     * The response location is the request URI, or the {@code Location} header resolved
     * against it (relative header values are allowed by HTTP). A body is parsed only when a
     * Content-Type header is present; only XML content is supported.
     *
     * @throws GcHttpException on a non-success status, a missing XML body, or an I/O error
     * @throws IllegalStateException if the response has a non-XML content type
     */
    private GcHttpResponse extractResponse(GcHttpRequest req, HttpMethod method)
            throws GcHttpException {
        try {
            ensureSuccess(method, req.getLocation());

            URI location = URI.create(method.getURI().toString());
            if (method.getResponseHeader("Location") != null) {
                String locationRedirect = method.getResponseHeader("Location").getValue();
                // resolve() returns absolute values unchanged and makes relative ones absolute
                location = location.resolve(locationRedirect);
            }

            if (method.getResponseHeader("Content-type") == null) {
                return new GcHttpResponse(location);
            }
            String contentType = method.getResponseHeader("Content-type").getValue();
            GcWebResourceType cat = GcUtil.categorizeContentType(contentType);
            if (cat != GcWebResourceType.XML) {
                throw new IllegalStateException("unsupported " + cat);
            }
            byte[] responseBody = method.getResponseBody();
            if (responseBody == null) {
                throw new GcHttpException("HTTP response from " + req.getLocation()
                        + " declares content type " + contentType + " but has no body");
            }
            logger.finest("responseBody=" + new String(responseBody, "UTF8"));
            XmlElement el = builder.parseFragmentFromInputStream(new ByteArrayInputStream(
                    responseBody));
            return new GcHttpResponse(location, contentType, el);
        } catch (IOException e) {
            throw new GcHttpException("HTTP transport response processing failed", e);
        }
    }

    /**
     * commons-httpclient socket factory that creates TLS sockets from a private
     * {@link SSLContext} initialised with the given key and trust managers, so that it does not
     * depend on the JVM-wide default SSL configuration.
     */
    private static class LimitedTrustSSLProtocolSocketFactory implements
            SecureProtocolSocketFactory {
        private final SSLContext sslcontext;

        /**
         * @param kms key managers providing the client identity; may be null
         * @param trustCerts trust managers deciding which servers are trusted
         */
        public LimitedTrustSSLProtocolSocketFactory(KeyManager[] kms, TrustManager[] trustCerts)
                throws KeyManagementException, NoSuchAlgorithmException {
            this.sslcontext = SSLContext.getInstance("TLS");
            sslcontext.init(kms, trustCerts, null);
        }

        private SSLContext getSSLContext() {
            return sslcontext;
        }

        /**
         * Creates a socket, honouring the connection timeout from {@code params}. A timeout of 0
         * connects directly; otherwise connecting is delegated to
         * {@link ControllerThreadSocketFactory}.
         */
        public Socket createSocket(final String host, final int port,
                final InetAddress localAddress, final int localPort,
                final HttpConnectionParams params) throws IOException, UnknownHostException,
                ConnectTimeoutException {
            if (params == null) { throw new IllegalArgumentException("Parameters may not be null"); }
            int timeout = params.getConnectionTimeout();
            if (timeout == 0) {
                return createSocket(host, port, localAddress, localPort);
            } else {
                // To be eventually deprecated when migrated to Java 1.4 or above
                return ControllerThreadSocketFactory.createSocket(this, host, port, localAddress,
                        localPort, timeout);
            }
        }

        /**
         * @see SecureProtocolSocketFactory#createSocket(java.lang.String,int,java.net.InetAddress,int)
         */
        public Socket createSocket(String host, int port, InetAddress clientHost, int clientPort)
                throws IOException, UnknownHostException {
            SSLSocket socket = (SSLSocket) getSSLContext().getSocketFactory().createSocket(
                    host, port, clientHost, clientPort);
            return socket;
        }

        /**
         * @see SecureProtocolSocketFactory#createSocket(java.lang.String,int)
         */
        public Socket createSocket(String host, int port) throws IOException, UnknownHostException {
            return getSSLContext().getSocketFactory().createSocket(host, port);
        }

        /**
         * @see SecureProtocolSocketFactory#createSocket(java.net.Socket,java.lang.String,int,boolean)
         */
        public Socket createSocket(Socket socket, String host, int port, boolean autoClose)
                throws IOException, UnknownHostException {
            return getSSLContext().getSocketFactory().createSocket(socket, host, port, autoClose);
        }

    }

    /**
     * Trust manager that accepts every certificate chain without performing any check.
     * <p>
     * WARNING: this disables server authentication, which leaves TLS connections open to
     * man-in-the-middle attacks.
     */
    public static class TrustAllX509TrustManager implements X509TrustManager {
        public TrustAllX509TrustManager() {
        }

        /**
         * Returns an empty array: no CA is advertised. The X509TrustManager contract requires
         * a non-null result.
         */
        public java.security.cert.X509Certificate[] getAcceptedIssuers() {
            return new X509Certificate[0];
        }

        /** Accepts any client chain. */
        public void checkClientTrusted(java.security.cert.X509Certificate[] certs, String authType) {
        }

        /** Accepts any server chain. */
        public void checkServerTrusted(java.security.cert.X509Certificate[] certs, String authType) {
            // TODO
        }
    }


    /**
     * Key manager that offers exactly one client identity (certificate chain plus private key)
     * under the alias {@code "default"}. It provides no server identities.
     */
    public static class OneChainKeyManager extends X509ExtendedKeyManager {
        private final X509Certificate[] certChain;

        private final PrivateKey privateKey;

        /**
         * @param certChain non-empty certificate chain, end-entity certificate first
         * @param privateKey private key matching the first certificate
         * @throws IllegalArgumentException if the chain is null/empty or the key is null
         */
        public OneChainKeyManager(X509Certificate[] certChain, PrivateKey privateKey) {
            if (certChain == null || certChain.length == 0) throw new IllegalArgumentException();
            if (privateKey == null) throw new IllegalArgumentException();
            // defensive copy so later changes to the caller's array cannot alter the identity
            this.certChain = certChain.clone();
            this.privateKey = privateKey;
        }

        public String chooseClientAlias(String[] keyTypes, Principal[] issuers, Socket socket) {
            return "default";
        }

        public String chooseServerAlias(String keyType, Principal[] issuers, Socket socket) {
            return null; // no server support
        }

        /** Returns a copy of the chain, whatever alias is asked for. */
        public X509Certificate[] getCertificateChain(String alias) {
            return certChain.clone();
        }

        public String[] getClientAliases(String keyType, Principal[] issuers) {
            return new String[] { "default" };
        }

        /** Returns the single private key, whatever alias is asked for. */
        public PrivateKey getPrivateKey(String alias) {
            return privateKey;
        }

        public String[] getServerAliases(String keyType, Principal[] issuers) {
            return null;
        }
    }

    /**
     * Trust manager that accepts a certificate chain when every link is valid and correctly
     * signed, and the last certificate of the chain is signed by one of the given trusted
     * certificates.
     * <p>
     * It does not check host names, key usage, basic constraints or revocation.
     */
    public static class TrustedListX509TrustManager implements X509TrustManager {
        private final X509Certificate[] trustedCerts;

        /**
         * @param trustedCerts non-empty list of trusted (CA) certificates
         * @throws IllegalArgumentException if {@code trustedCerts} is null or empty
         */
        public TrustedListX509TrustManager(final X509Certificate[] trustedCerts) {
            if (trustedCerts == null || trustedCerts.length == 0) {
                throw new IllegalArgumentException();
            }
            this.trustedCerts = trustedCerts.clone();
        }

        public X509Certificate[] getAcceptedIssuers() {
            return trustedCerts.clone();
        }

        public void checkClientTrusted(X509Certificate[] certs, String authType)
        throws CertificateException {
            checkTrusted(certs, authType);
        }

        public void checkServerTrusted(X509Certificate[] certs, String authType)
        throws CertificateException {
            checkTrusted(certs, authType);
        }

        /**
         * Validates {@code certs} (end-entity certificate first).
         *
         * @throws CertificateException if the chain is empty, a certificate is not currently
         *         valid, a link is broken, or no trusted certificate signed the chain's last
         *         certificate
         */
        private void checkTrusted(X509Certificate[] certs, String authType)
        throws CertificateException {
            if (certs == null || certs.length == 0) {
                throw new CertificateException("missing certificate chain to verify");
            }
            certs[0].checkValidity();
            // TODO verify that first certificate is for host we want to access
            // (for example against its subject alternative names); not checked here.

            // check validity of chain: certs[i] must be the issuer of certs[i - 1] and have
            // signed it
            for (int i = 1; i < certs.length; i++) {
                X509Certificate cert = certs[i];
                cert.checkValidity();
                X509Certificate prevCert = certs[i - 1];
                Principal principalSubject = cert.getSubjectX500Principal();
                Principal prevPrincipalIssuer = prevCert.getIssuerX500Principal();
                if (!principalSubject.equals(prevPrincipalIssuer)) {
                    throw new CertificateException(
                            "certificate chain invalid: issuer " + prevPrincipalIssuer +
                            " is not the same as next certificate subject " + principalSubject
                            + " (previous=" + prevCert + " current=" + cert + ")");
                }
                try {
                    prevCert.verify(cert.getPublicKey());
                } catch (InvalidKeyException e) {
                    throw new CertificateException(e);
                } catch (NoSuchAlgorithmException e) {
                    throw new CertificateException(e);
                } catch (NoSuchProviderException e) {
                    throw new CertificateException(e);
                } catch (SignatureException e) {
                    throw new CertificateException(e);
                }
            }

            final X509Certificate lastCertificateInChain = certs[certs.length - 1];
            // verify if top level certificate in chain is signed by trusted cert; a trusted
            // certificate that is unusable for any reason is skipped, not fatal
            for (int i = 0; i < trustedCerts.length; i++) {
                X509Certificate trustedCert = trustedCerts[i];
                try {
                    trustedCert.checkValidity();
                    lastCertificateInChain.verify(trustedCert.getPublicKey());
                    return; // verification OK
                } catch (CertificateExpiredException e) {
                    // This happens if one of the trusted certificates has expired.
                    // continue
                } catch (CertificateException e) {
                    // trusted certificate not yet valid, or an encoding problem during verify
                    // continue
                } catch (InvalidKeyException e) {
                    // continue
                } catch (NoSuchAlgorithmException e) {
                    // continue
                } catch (NoSuchProviderException e) {
                    // continue
                } catch (SignatureException e) {
                    // continue
                }
            }
            throw new CertificateException(
                    "could not find trusted CA to verify " + lastCertificateInChain);
        }
    }

}
