/**
 * JOnAS: Java(TM) Open Application Server
 * Copyright (C) 2008-2009 Bull S.A.S.
 * Contact: jonas-team@ow2.org
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public
 * License as published by the Free Software Foundation; either
 * version 2.1 of the License, or any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with this library; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307
 * USA
 *
 * --------------------------------------------------------------------------
 * $Id: WSDLQueryHandler.java 21566 2011-08-08 12:28:12Z cazauxj $
 * --------------------------------------------------------------------------
 */
package org.ow2.jonas.ws.axis2.http;

import java.io.FileNotFoundException;
import java.io.OutputStream;
import java.net.MalformedURLException;
import java.net.URISyntaxException;
import java.net.URL;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

import javax.wsdl.Definition;
import javax.wsdl.Import;
import javax.wsdl.Port;
import javax.wsdl.Service;
import javax.wsdl.Types;
import javax.wsdl.extensions.ExtensibilityElement;
import javax.wsdl.extensions.schema.Schema;
import javax.wsdl.extensions.schema.SchemaImport;
import javax.wsdl.extensions.schema.SchemaReference;
import javax.wsdl.extensions.soap.SOAPAddress;
import javax.wsdl.extensions.soap12.SOAP12Address;
import javax.wsdl.factory.WSDLFactory;
import javax.wsdl.xml.WSDLReader;
import javax.wsdl.xml.WSDLWriter;
import javax.xml.namespace.QName;
import javax.xml.transform.OutputKeys;
import javax.xml.transform.Source;
import javax.xml.transform.Transformer;
import javax.xml.transform.TransformerException;
import javax.xml.transform.TransformerFactory;
import javax.xml.transform.dom.DOMSource;
import javax.xml.transform.stream.StreamResult;

import org.apache.axis2.description.AxisService;
import org.ow2.jonas.ws.axis2.util.WSDLUtils;
import org.ow2.util.log.Log;
import org.ow2.util.log.LogFactory;
import org.w3c.dom.Element;
import org.w3c.dom.Node;
import org.w3c.dom.NodeList;

/**
 * Serves {@code ?wsdl} and {@code ?xsd=} queries for an Axis2 service.
 * <p>
 * On first use the service WSDL is parsed (with imported documents). Every relatively-located
 * imported WSDL definition and XML schema is cached by its location string, and the root
 * definition is trimmed to this service/endpoint. Each response is written with the cached
 * relative {@code schemaLocation}/{@code location} attributes rewritten to point back at this
 * handler ({@code base?xsd=...} / {@code base?wsdl=...}).
 */
public class WSDLQueryHandler {

    private static final Log LOG = LogFactory.getLog(WSDLQueryHandler.class);

    /** XML Schema namespace (for xsd:import / xsd:include rewriting). */
    private static final String XSD_NS = "http://www.w3.org/2001/XMLSchema";

    /** WSDL 1.1 namespace (for wsdl:import rewriting). */
    private static final String WSDL_NS = "http://schemas.xmlsoap.org/wsdl/";

    /** Query marker selecting a WSDL document. */
    private static final String WSDL_QUERY = "?wsdl";

    /** Query marker selecting a schema document. */
    private static final String XSD_QUERY = "?xsd=";

    /** Key under which the root (trimmed) service definition is cached in {@link #mp}. */
    private static final String ROOT_WSDL_KEY = "";

    /** Cached WSDL definitions keyed by import location; the root definition is under {@code ""}. */
    private final Map<String, Definition> mp = new ConcurrentHashMap<String, Definition>();

    /** Cached schema references keyed by their relative schemaLocation. */
    private final Map<String, SchemaReference> smp = new ConcurrentHashMap<String, SchemaReference>();

    private final AxisService service;

    public WSDLQueryHandler(final AxisService service) {
        this.service = service;
    }

    /**
     * Writes the WSDL or schema document selected by the query part of {@code baseUri} to {@code os}.
     * <p>
     * {@code baseUri} must contain either {@code ?wsdl} or {@code ?xsd=}, matched
     * case-insensitively.
     * <ul>
     * <li>{@code ?wsdl} may be followed by one separator character and the location of a cached
     * imported WSDL. An empty remainder selects the root WSDL.</li>
     * <li>{@code ?xsd=} is followed by the location of a cached schema.</li>
     * </ul>
     * Everything before the marker is used as the base URL for rewritten locations.
     *
     * @param baseUri request URI including the query
     * @param wsdlUri URI of the service WSDL, read the first time it is needed
     * @param os stream the XML document is written to
     * @throws FileNotFoundException if the requested WSDL or schema is not known
     * @throws IllegalArgumentException if {@code baseUri} contains neither query marker
     * @throws Exception if the WSDL cannot be read or serialized
     */
    public void writeResponse(final String baseUri, final String wsdlUri, final OutputStream os) throws Exception {
        String base;
        String wsdl = ROOT_WSDL_KEY;
        String xsd = null;

        // Indices are computed on baseUri itself (not on a lower-cased copy whose length could
        // differ) so they are always valid for the substrings below.
        int idx = indexOfIgnoreCase(baseUri, WSDL_QUERY);
        if (idx != -1) {
            base = baseUri.substring(0, idx);
            wsdl = baseUri.substring(idx + WSDL_QUERY.length());
            if (wsdl.length() > 0) {
                // Drop the separator following "?wsdl" (normally '=').
                wsdl = wsdl.substring(1);
            }
        } else {
            idx = indexOfIgnoreCase(baseUri, XSD_QUERY);
            if (idx == -1) {
                throw new IllegalArgumentException("Neither ?wsdl nor ?xsd= found in: " + baseUri);
            }
            base = baseUri.substring(0, idx);
            xsd = baseUri.substring(idx + XSD_QUERY.length());
        }

        // Load once: the root entry is stored last, after all imports/schemas are registered, so
        // its presence means the caches are complete. Unknown names must not trigger a re-parse.
        if (!mp.containsKey(ROOT_WSDL_KEY)) {
            loadDefinitions(wsdlUri, base);
        }

        Element rootElement;
        if (xsd == null) {
            rootElement = wsdlElement(wsdl, base);
        } else {
            rootElement = schemaElement(xsd);
        }

        rewriteLocations(rootElement, XSD_NS, "import", "schemaLocation", smp, base + "?xsd=");
        rewriteLocations(rootElement, XSD_NS, "include", "schemaLocation", smp, base + "?xsd=");
        rewriteLocations(rootElement, WSDL_NS, "import", "location", mp, base + "?wsdl=");

        writeTo(rootElement, os);
    }

    /**
     * Parses the service WSDL, registers its relative imports/schemas, trims it to this service
     * and endpoint, and caches it under {@link #ROOT_WSDL_KEY}.
     */
    private void loadDefinitions(final String wsdlUri, final String base) throws Exception {
        WSDLFactory factory = WSDLFactory.newInstance();
        WSDLReader reader = factory.newWSDLReader();
        reader.setFeature("javax.wsdl.importDocuments", true);
        reader.setFeature("javax.wsdl.verbose", false);
        Definition def = reader.readWSDL(wsdlUri);
        updateDefinition(def, mp, smp, base);
        WSDLUtils.trimDefinition(def, this.service.getName(), this.service.getEndpointName());
        mp.put(ROOT_WSDL_KEY, def);
    }

    /**
     * Serializes the cached definition {@code wsdl} into a fresh DOM.
     * For the root definition, its port locations are first updated to {@code base}.
     */
    private Element wsdlElement(final String wsdl, final String base) throws Exception {
        Definition def = mp.get(wsdl);
        if (def == null) {
            throw new FileNotFoundException("WSDL not found: " + wsdl);
        }

        WSDLWriter writer = WSDLFactory.newInstance().newWSDLWriter();

        // The cached definition is shared by all requests. Updating the root's port locations and
        // serializing it must not interleave with another request using a different base URL.
        synchronized (def) {
            if (ROOT_WSDL_KEY.equals(wsdl)) {
                WSDLUtils.updateLocations(def, base);
            }
            return writer.getDocument(def).getDocumentElement();
        }
    }

    /**
     * Returns a private copy of the cached schema element for {@code xsd}.
     * <p>
     * The copy is required because callers rewrite its location attributes. Mutating the cached
     * DOM would leak one request's base URL into every later response. It would also stop later
     * rewrites, since the rewritten value no longer matches a cache key.
     */
    private Element schemaElement(final String xsd) throws FileNotFoundException {
        SchemaReference si = smp.get(xsd);
        if (si == null || si.getReferencedSchema() == null) {
            throw new FileNotFoundException("Schema not found: " + xsd);
        }
        return (Element) si.getReferencedSchema().getElement().cloneNode(true);
    }

    /**
     * Points every descendant {@code namespace:localName} element whose {@code attribute} value
     * is a key of {@code known} back at this handler, by prefixing the value with {@code prefix}.
     */
    private static void rewriteLocations(final Element root, final String namespace, final String localName,
            final String attribute, final Map<String, ?> known, final String prefix) {
        NodeList nodes = root.getElementsByTagNameNS(namespace, localName);
        for (int i = 0; i < nodes.getLength(); i++) {
            Element el = (Element) nodes.item(i);
            String location = el.getAttribute(attribute);
            // getAttribute returns "" when absent; "" is the root WSDL key and must not match.
            if (location.length() > 0 && known.containsKey(location)) {
                el.setAttribute(attribute, prefix + location);
            }
        }
    }

    /**
     * Case-insensitive {@code indexOf} that returns an index valid for {@code text} itself.
     * Returns -1 when {@code token} does not occur.
     */
    private static int indexOfIgnoreCase(final String text, final String token) {
        int last = text.length() - token.length();
        for (int i = 0; i <= last; i++) {
            if (text.regionMatches(true, i, token, 0, token.length())) {
                return i;
            }
        }
        return -1;
    }

    /**
     * Returns {@code true} if {@code location} parses as an absolute {@link URL}.
     * Such locations are left untouched by this handler.
     */
    private static boolean isUrl(final String location) {
        try {
            new URL(location);
            return true;
        } catch (MalformedURLException ignored) {
            // Not an absolute URL: a relative location this handler serves itself.
            return false;
        }
    }

    /**
     * Recursively registers the relatively-located WSDL imports of {@code def} in {@code done}.
     * Also registers the schemas referenced from its types section in {@code doneSchemas}.
     * Absolute URLs are skipped. Already-registered locations are not revisited, so cyclic
     * imports terminate.
     *
     * @param def definition to scan
     * @param done imported definitions keyed by their import location
     * @param doneSchemas schema references keyed by their schemaLocation
     * @param base base URL of the current request (currently unused)
     */
    @SuppressWarnings("unchecked")
    protected void updateDefinition(final Definition def, final Map<String, Definition> done,
            final Map<String, SchemaReference> doneSchemas, final String base) {
        Collection<List> imports = def.getImports().values();
        for (List lst : imports) {
            List<Import> impLst = lst;
            for (Import imp : impLst) {
                String start = imp.getLocationURI();
                Definition imported = imp.getDefinition();
                // ConcurrentHashMap rejects null keys/values; skip imports that have no location
                // or could not be resolved.
                if (start == null || imported == null || isUrl(start) || done.containsKey(start)) {
                    continue;
                }
                done.put(start, imported);
                updateDefinition(imported, done, doneSchemas, base);
            }
        }

        // Schema locations are only collected here. Setting setSchemaLocationURI on the import
        // is not reflected in the written WSDL, so writeResponse rewrites the attributes on the
        // serialized DOM instead.
        Types types = def.getTypes();
        if (types != null) {
            for (ExtensibilityElement el : (List<ExtensibilityElement>) types.getExtensibilityElements()) {
                if (el instanceof Schema) {
                    updateSchemaImports((Schema) el, doneSchemas, base);
                }
            }
        }
    }

    /**
     * Recursively registers the relatively-located imports and includes of {@code schema} in
     * {@code doneSchemas}.
     * Absolute URLs are skipped. Already-registered locations are not revisited. A {@code null}
     * schema (an unresolved reference) is ignored.
     *
     * @param schema schema to scan, may be {@code null}
     * @param doneSchemas schema references keyed by their schemaLocation
     * @param base base URL of the current request (currently unused)
     */
    @SuppressWarnings("unchecked")
    protected void updateSchemaImports(final Schema schema, final Map<String, SchemaReference> doneSchemas,
            final String base) {
        if (schema == null) {
            return;
        }
        Collection<List> imports = schema.getImports().values();
        for (List lst : imports) {
            List<SchemaImport> impLst = lst;
            for (SchemaImport imp : impLst) {
                registerSchema(imp, doneSchemas, base);
            }
        }
        List<SchemaReference> includes = schema.getIncludes();
        for (SchemaReference included : includes) {
            registerSchema(included, doneSchemas, base);
        }
    }

    /**
     * Registers {@code ref} under its schemaLocation (if relative and not yet known) and
     * recurses into the schema it references.
     */
    private void registerSchema(final SchemaReference ref, final Map<String, SchemaReference> doneSchemas,
            final String base) {
        String start = ref.getSchemaLocationURI();
        if (start == null || isUrl(start) || doneSchemas.containsKey(start)) {
            return;
        }
        doneSchemas.put(start, ref);
        updateSchemaImports(ref.getReferencedSchema(), doneSchemas, base);
    }

    /**
     * Serializes {@code node} as indented UTF-8 XML to {@code os}.
     */
    public static void writeTo(final Node node, final OutputStream os) {
        writeTo(new DOMSource(node), os);
    }

    /**
     * Serializes {@code src} as indented UTF-8 XML (with XML declaration) to {@code os}.
     * Transformation errors are logged, not propagated.
     */
    public static void writeTo(final Source src, final OutputStream os) {
        try {
            Transformer it = TransformerFactory.newInstance().newTransformer();
            it.setOutputProperty(OutputKeys.METHOD, "xml");
            it.setOutputProperty(OutputKeys.INDENT, "yes");
            it.setOutputProperty("{http://xml.apache.org/xslt}indent-amount", "4");
            it.setOutputProperty(OutputKeys.OMIT_XML_DECLARATION, "false");
            it.setOutputProperty(OutputKeys.ENCODING, "utf-8");
            it.transform(src, new StreamResult(os));
        } catch (TransformerException e) {
            LOG.error("Unable to write XML document", e);
        }
    }

    /**
     * Removes every service of {@code def} whose local name is not {@code serviceName} and
     * updates the ports of the matching one.
     * <p>
     * Not called at present: {@link #writeResponse} uses {@code WSDLUtils.trimDefinition} and
     * {@code WSDLUtils.updateLocations} instead.
     */
    private void updateServices(final String serviceName, final String portName, final Definition def,
            final String baseUri) throws Exception {
        boolean updated = false;
        Map services = def.getServices();
        if (services != null) {
            List<QName> servicesToRemove = new ArrayList<QName>();

            Iterator serviceIterator = services.entrySet().iterator();
            while (serviceIterator.hasNext()) {
                Map.Entry serviceEntry = (Map.Entry) serviceIterator.next();
                QName currServiceName = (QName) serviceEntry.getKey();
                if (currServiceName.getLocalPart().equals(serviceName)) {
                    Service svc = (Service) serviceEntry.getValue();
                    updatePorts(portName, svc, baseUri);
                    updated = true;
                } else {
                    servicesToRemove.add(currServiceName);
                }
            }

            // Removal is deferred so the services map is not modified while being iterated.
            for (QName serviceToRemove : servicesToRemove) {
                def.removeService(serviceToRemove);
            }
        }
        if (!updated) {
            LOG.warn("WSDL '" + serviceName + "' service not found.");
        }
    }

    /**
     * Removes every port of {@code svc} not named {@code portName}.
     * Sets the address location of the matching port to {@code baseUri}.
     */
    private void updatePorts(final String portName, final Service svc, final String baseUri) throws Exception {
        boolean updated = false;
        Map ports = svc.getPorts();
        if (ports != null) {
            List<String> portsToRemove = new ArrayList<String>();

            Iterator portIterator = ports.entrySet().iterator();
            while (portIterator.hasNext()) {
                Map.Entry portEntry = (Map.Entry) portIterator.next();
                String currPortName = (String) portEntry.getKey();
                if (currPortName.equals(portName)) {
                    Port port = (Port) portEntry.getValue();
                    updatePortLocation(port, baseUri);
                    updated = true;
                } else {
                    portsToRemove.add(currPortName);
                }
            }

            // Removal is deferred so the ports map is not modified while being iterated.
            for (String portToRemove : portsToRemove) {
                svc.removePort(portToRemove);
            }
        }
        if (!updated) {
            LOG.warn("WSDL '" + portName + "' port not found.");
        }
    }

    /**
     * Sets the location of {@code port}'s address to {@code baseUri}.
     * Only the first extensibility element is inspected, and only if it is a SOAP 1.2 or
     * SOAP 1.1 address.
     */
    private void updatePortLocation(final Port port, final String baseUri) throws URISyntaxException {
        List<?> exts = port.getExtensibilityElements();
        if (exts != null && exts.size() > 0) {
            ExtensibilityElement el = (ExtensibilityElement) exts.get(0);
            if (el instanceof SOAP12Address) {
                ((SOAP12Address) el).setLocationURI(baseUri);
            } else if (el instanceof SOAPAddress) {
                ((SOAPAddress) el).setLocationURI(baseUri);
            }
        }
    }
}
