/**
 * JOnAS: Java(TM) Open Application Server
 * Copyright (C) 2008-2009 Bull S.A.S.
 * Contact: jonas-team@ow2.org
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public
 * License as published by the Free Software Foundation; either
 * version 2.1 of the License, or any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with this library; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307
 * USA
 *
 * --------------------------------------------------------------------------
 * $Id: Axis2EjbMessageReceiver.java 21566 2011-08-08 12:28:12Z cazauxj $
 * --------------------------------------------------------------------------
 */

package org.ow2.jonas.ws.axis2.easybeans;

import java.lang.reflect.Method;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;

import javax.xml.ws.Binding;
import javax.xml.ws.Provider;

import org.apache.axis2.AxisFault;
import org.apache.axis2.context.OperationContext;
import org.apache.axis2.description.AxisOperation;
import org.apache.axis2.description.AxisService;
import org.apache.axis2.description.Parameter;
import org.apache.axis2.description.WSDL2Constants;
import org.apache.axis2.engine.AxisEngine;
import org.apache.axis2.engine.MessageReceiver;
import org.apache.axis2.jaxws.ExceptionFactory;
import org.apache.axis2.jaxws.core.InvocationContextFactory;
import org.apache.axis2.jaxws.core.MessageContext;
import org.apache.axis2.jaxws.description.EndpointDescription;
import org.apache.axis2.jaxws.description.EndpointInterfaceDescription;
import org.apache.axis2.jaxws.description.OperationDescription;
import org.apache.axis2.jaxws.i18n.Messages;
import org.apache.axis2.jaxws.message.util.MessageUtils;
import org.apache.axis2.jaxws.server.EndpointInvocationContext;
import org.apache.axis2.jaxws.server.JAXWSMessageReceiver;
import org.apache.axis2.jaxws.core.InvocationContext;
import org.apache.axis2.wsdl.WSDLConstants.WSDL20_2004_Constants;
import org.apache.axis2.wsdl.WSDLConstants.WSDL20_2006Constants;
import org.ow2.easybeans.container.session.stateless.StatelessSessionFactory;
import org.ow2.jonas.ws.axis2.jaxws.Axis2WSEndpoint;
import org.ow2.util.log.Log;
import org.ow2.util.log.LogFactory;

/**
 * Substitute for Axis2's JAX-WS MessageReceiver.
 * @author youchao
 * @author xiaoda
 */
public class Axis2EjbMessageReceiver implements MessageReceiver {

    /** Logger for this class. */
    private static final Log LOG = LogFactory.getLog(Axis2EjbMessageReceiver.class);

    /** Session factory handed to the {@link Axis2EndpointController} used for each invocation. */
    private final StatelessSessionFactory factory;

    /** Service implementation class; used to resolve the Java method targeted by a request. */
    private final Class<?> serviceImplClass;

    /** Endpoint this receiver was created for (stored but not read within this class). */
    private final Axis2WSEndpoint endpoint;

    /**
     * Creates a new message receiver.
     * @param endpoint the web service endpoint this receiver belongs to
     * @param serviceImplClass the service implementation class
     * @param factory the session factory passed to the endpoint controller
     */
    public Axis2EjbMessageReceiver(final Axis2WSEndpoint endpoint, final Class serviceImplClass, final StatelessSessionFactory factory) {
        this.endpoint = endpoint;
        this.serviceImplClass = serviceImplClass;
        this.factory = factory;
    }

    /**
     * Handles an incoming Axis2 message: wraps it in a JAX-WS message context,
     * resolves the target method, delegates the invocation to an
     * {@link Axis2EndpointController} and, when a response (or fault) is
     * produced, hands it back to Axis2.
     * @param axisMsgCtx the incoming Axis2 message context
     * @throws AxisFault if the response message is a fault, or if sending the response fails
     */
    public void receive(final org.apache.axis2.context.MessageContext axisMsgCtx) throws AxisFault {
        MessageContext requestMsgCtx = new MessageContext(axisMsgCtx);

        // init some bits
        requestMsgCtx.setOperationName(requestMsgCtx.getAxisMessageContext().getAxisOperation().getName());
        requestMsgCtx.setEndpointDescription(getEndpointDescription(requestMsgCtx));

        // The method is only used for logging below, but resolving it also
        // validates the request: both lookups throw when nothing matches.
        Method method;
        if (Provider.class.isAssignableFrom(this.serviceImplClass)) {
            method = getProviderMethod();
        } else {
            requestMsgCtx.setOperationDescription(getOperationDescription(requestMsgCtx));
            method = getServiceMethod(requestMsgCtx);
        }

        if (LOG.isDebugEnabled()) {
            LOG.debug("Invoking '" + method.getName() + "' method.");
        }

        AxisOperation operation = requestMsgCtx.getAxisMessageContext().getAxisOperation();
        String mep = operation.getMessageExchangePattern();

        // Use our own controller
        Axis2EndpointController controller = new Axis2EndpointController(factory);
        Binding binding = (Binding) requestMsgCtx.getAxisMessageContext().getProperty(JAXWSMessageReceiver.PARAM_BINDING);
        EndpointInvocationContext ic = InvocationContextFactory.createEndpointInvocationContext(binding);
        ic.setRequestMessageContext(requestMsgCtx);

        // invoke by Axis2EndpointController
        controller.invoke(ic);

        MessageContext responseMsgCtx = ic.getResponseMessageContext();

        // If there is a fault it could be Robust In-Only
        if (!isMepInOnly(mep) || hasFault(responseMsgCtx)) {
            // If this is a two-way exchange, there should already be a
            // JAX-WS MessageContext for the response. We need to pull
            // the Message data out of there and set it on the Axis2
            // MessageContext.
            org.apache.axis2.context.MessageContext axisResponseMsgCtx = responseMsgCtx.getAxisMessageContext();

            MessageUtils.putMessageOnMessageContext(responseMsgCtx.getMessage(), axisResponseMsgCtx);

            OperationContext opCtx = axisResponseMsgCtx.getOperationContext();
            opCtx.addMessageContext(axisResponseMsgCtx);

            // If this is a fault message, we want to throw it as an
            // exception so that the transport can do the appropriate things
            if (responseMsgCtx.getMessage().isFault()) {
                throw new AxisFault("An error was detected during JAXWS processing", axisResponseMsgCtx);
            } else {
                AxisEngine.send(axisResponseMsgCtx);
            }
        }
    }

    /**
     * Tells whether the given response context carries a fault message.
     * @param responseMsgCtx the response context, may be <code>null</code>
     * @return <code>true</code> only if a message is present and it is a fault
     */
    private boolean hasFault(final MessageContext responseMsgCtx) {
        if (responseMsgCtx == null || responseMsgCtx.getMessage() == null) {
            return false;
        }
        return responseMsgCtx.getMessage().isFault();
    }

    /**
     * Tells whether the given message exchange pattern URI is one of the
     * In-Only or Robust-In-Only constants checked below.
     * @param mep the message exchange pattern URI of the operation
     * @return <code>true</code> if the pattern is an (optionally robust) In-Only pattern
     */
    private boolean isMepInOnly(final String mep) {
        return mep.equals(WSDL20_2004_Constants.MEP_URI_ROBUST_IN_ONLY)
                || mep.equals(WSDL20_2004_Constants.MEP_URI_IN_ONLY)
                || mep.equals(WSDL2Constants.MEP_URI_IN_ONLY)
                || mep.equals(WSDL2Constants.MEP_URI_ROBUST_IN_ONLY)
                || mep.equals(WSDL20_2006Constants.MEP_URI_ROBUST_IN_ONLY)
                || mep.equals(WSDL20_2006Constants.MEP_URI_IN_ONLY);
    }

    /**
     * Resolves the method of the service implementation class that matches
     * the operation description stored on the message context.
     * @param mc the request message context
     * @return the matching method, never <code>null</code>
     * @throws javax.xml.ws.WebServiceException if no operation description is
     *         set or no matching method exists
     */
    private Method getServiceMethod(final MessageContext mc) {
        OperationDescription opDesc = mc.getOperationDescription();
        if (opDesc == null) {
            throw ExceptionFactory.makeWebServiceException("Operation Description was not set");
        }

        Method returnMethod = opDesc.getMethodFromServiceImpl(this.serviceImplClass);
        if (returnMethod == null) {
            throw ExceptionFactory.makeWebServiceException(Messages.getMessage("JavaBeanDispatcherErr1"));
        }

        return returnMethod;
    }

    /**
     * Looks up the single dispatchable operation matching the operation name
     * of the message context.
     * @param mc the request message context (endpoint description and operation name must be set)
     * @return the unique matching operation description
     * @throws javax.xml.ws.WebServiceException if zero or several operations match
     */
    private OperationDescription getOperationDescription(final MessageContext mc) {
        EndpointDescription ed = mc.getEndpointDescription();
        EndpointInterfaceDescription eid = ed.getEndpointInterfaceDescription();

        OperationDescription[] ops = eid.getDispatchableOperation(mc.getOperationName());
        if (ops == null || ops.length == 0) {
            throw ExceptionFactory
                    .makeWebServiceException("No operation found.  WSDL Operation name: " + mc.getOperationName());
        }
        if (ops.length > 1) {
            throw ExceptionFactory
                    .makeWebServiceException("More than one operation found. Overloaded WSDL operations are not supported.  WSDL Operation name: "
                            + mc.getOperationName());
        }
        return ops[0];
    }

    /**
     * Reads the endpoint description stored as a parameter on the AxisService
     * of the message context.
     * @param mc the request message context
     * @return the endpoint description held by the service parameter
     * @throws javax.xml.ws.WebServiceException if the AxisService carries no
     *         endpoint description parameter
     */
    private EndpointDescription getEndpointDescription(final MessageContext mc) {
        AxisService axisSvc = mc.getAxisMessageContext().getAxisService();

        Parameter param = axisSvc.getParameter(EndpointDescription.AXIS_SERVICE_PARAMETER);
        if (param == null) {
            // Fail with an explicit message rather than an anonymous NullPointerException.
            throw ExceptionFactory.makeWebServiceException("No EndpointDescription parameter found on AxisService '"
                    + axisSvc.getName() + "'");
        }

        return (EndpointDescription) param.getValue();
    }

    /**
     * Resolves the <code>invoke</code> method of a Provider based service
     * implementation, using the Provider type argument as parameter type.
     * @return the <code>invoke</code> method
     * @throws javax.xml.ws.WebServiceException if the method cannot be found
     */
    private Method getProviderMethod() {
        try {
            return this.serviceImplClass.getMethod("invoke", getProviderType());
        } catch (NoSuchMethodException e) {
            // Keep the original exception as the cause to ease diagnosis.
            throw ExceptionFactory.makeWebServiceException("Could not get Provider.invoke() method", e);
        }
    }

    /**
     * Finds the type argument <code>T</code> of the
     * <code>javax.xml.ws.Provider&lt;T&gt;</code> interface directly
     * implemented by the service implementation class.
     * <p>
     * Interfaces other than Provider are ignored, so the implementation class
     * may implement additional (generic or non-generic) interfaces.
     * </p>
     * @return the Provider type argument, or <code>null</code> if the class
     *         does not directly implement a parameterized Provider
     * @throws javax.xml.ws.WebServiceException if Provider is implemented as a
     *         raw type or with a type argument that is not a plain class
     */
    private Class<?> getProviderType() {
        Class<?> providerType = null;

        for (Type giType : this.serviceImplClass.getGenericInterfaces()) {
            if (!(giType instanceof ParameterizedType)) {
                // Only a raw Provider is an error; unrelated non-generic
                // interfaces are simply skipped.
                if (giType == Provider.class) {
                    throw ExceptionFactory
                            .makeWebServiceException("Provider based SEI Class has to implement javax.xml.ws.Provider as javax.xml.ws.Provider<String>, javax.xml.ws.Provider<SOAPMessage>, javax.xml.ws.Provider<Source> or javax.xml.ws.Provider<JAXBContext>");
                }
                continue;
            }

            ParameterizedType paramType = (ParameterizedType) giType;
            if (paramType.getRawType() != Provider.class) {
                continue;
            }

            Type[] typeArgs = paramType.getActualTypeArguments();
            if (typeArgs.length > 1) {
                throw ExceptionFactory
                        .makeWebServiceException("Provider cannot have more than one Generic Types defined as Per JAX-WS Specification");
            }
            if (!(typeArgs[0] instanceof Class)) {
                // Type variables or nested generic types cannot be used to
                // look up the invoke() method by reflection.
                throw ExceptionFactory
                        .makeWebServiceException("Unsupported Provider type argument: " + typeArgs[0]);
            }
            providerType = (Class<?>) typeArgs[0];
        }
        return providerType;
    }

}
