001/**
002 *   GRANITE DATA SERVICES
003 *   Copyright (C) 2006-2013 GRANITE DATA SERVICES S.A.S.
004 *
005 *   This file is part of the Granite Data Services Platform.
006 *
007 *   Granite Data Services is free software; you can redistribute it and/or
008 *   modify it under the terms of the GNU Lesser General Public
009 *   License as published by the Free Software Foundation; either
010 *   version 2.1 of the License, or (at your option) any later version.
011 *
012 *   Granite Data Services is distributed in the hope that it will be useful,
013 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
014 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser
015 *   General Public License for more details.
016 *
017 *   You should have received a copy of the GNU Lesser General Public
018 *   License along with this library; if not, write to the Free Software
019 *   Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301,
020 *   USA, or see <http://www.gnu.org/licenses/>.
021 */
022package org.granite.tide.spring;
023
024import java.lang.reflect.Method;
025import java.util.ArrayList;
026import java.util.Enumeration;
027import java.util.HashMap;
028import java.util.Hashtable;
029import java.util.List;
030import java.util.Map;
031import java.util.Set;
032
033import javax.servlet.ServletRequest;
034import javax.servlet.http.HttpServletRequest;
035import javax.servlet.http.HttpServletRequestWrapper;
036import javax.servlet.http.HttpServletResponse;
037
038import org.granite.context.GraniteContext;
039import org.granite.logging.Logger;
040import org.granite.messaging.amf.io.convert.Converter;
041import org.granite.messaging.amf.io.util.ClassGetter;
042import org.granite.messaging.service.ServiceException;
043import org.granite.messaging.service.ServiceInvocationContext;
044import org.granite.messaging.webapp.HttpGraniteContext;
045import org.granite.tide.IInvocationCall;
046import org.granite.tide.IInvocationResult;
047import org.granite.tide.annotations.BypassTideMerge;
048import org.granite.tide.data.DataContext;
049import org.granite.tide.invocation.ContextUpdate;
050import org.granite.tide.invocation.InvocationCall;
051import org.granite.tide.invocation.InvocationResult;
052import org.granite.util.TypeUtil;
053import org.springframework.beans.TypeMismatchException;
054import org.springframework.beans.factory.NoSuchBeanDefinitionException;
055import org.springframework.context.ApplicationContext;
056import org.springframework.core.MethodParameter;
057import org.springframework.web.bind.ServletRequestDataBinder;
058import org.springframework.web.context.request.WebRequestInterceptor;
059import org.springframework.web.servlet.HandlerAdapter;
060import org.springframework.web.servlet.HandlerInterceptor;
061import org.springframework.web.servlet.ModelAndView;
062import org.springframework.web.servlet.handler.WebRequestHandlerInterceptorAdapter;
063import org.springframework.web.servlet.mvc.Controller;
064import org.springframework.web.servlet.mvc.SimpleControllerHandlerAdapter;
065import org.springframework.web.servlet.mvc.annotation.AnnotationMethodHandlerAdapter;
066
067
068/**
069 * @author William DRAI
070 */
071public class SpringMVCServiceContext extends SpringServiceContext {
072
073    private static final long serialVersionUID = 1L;
074    
075    private static final String REQUEST_VALUE = "__REQUEST_VALUE__";
076    
077    private static final Logger log = Logger.getLogger(SpringMVCServiceContext.class);
078
079    
080    public SpringMVCServiceContext() throws ServiceException {
081        super();
082    }
083    
084    public SpringMVCServiceContext(ApplicationContext springContext) throws ServiceException {
085        super(springContext);
086    }
087    
088    
089    @Override
090    public Object adjustInvokee(Object instance, String componentName, Set<Class<?>> componentClasses) {
091        for (Class<?> componentClass : componentClasses) {
092                if (componentClass.isAnnotationPresent(org.springframework.stereotype.Controller.class)) {
093                        return new AnnotationMethodHandlerAdapter() {
094                                @Override
095                                protected ServletRequestDataBinder createBinder(HttpServletRequest request, Object target, String objectName) throws Exception {
096                                        return new ControllerRequestDataBinder(request, target, objectName);
097                                }
098                        };
099                }
100        }
101        if (Controller.class.isInstance(instance) || (componentName != null && componentName.endsWith("Controller")))
102                return new SimpleControllerHandlerAdapter();
103        
104        return instance;
105    }
106    
107    @Override
108    protected Object internalFindComponent(String componentName, Class<?> componentClass) {
109        try {
110                return super.internalFindComponent(componentName, componentClass);
111        }
112        catch (NoSuchBeanDefinitionException e) {
113//              Map<String, HandlerMapping> handlerMappingsMap = springContext.getBeansOfType(HandlerMapping.class);
114//              if (handlerMappingsMap)
115                if (componentName != null && componentName.endsWith("Controller")) {
116                        try {
117                                int idx = componentName.lastIndexOf(".");
118                                String controllerName = idx > 0 
119                                        ? componentName.substring(0, idx+1) + componentName.substring(idx+1, idx+2).toUpperCase() + componentName.substring(idx+2)
120                                        : componentName.substring(0, 1).toUpperCase() + componentName.substring(1);
121                                        
122                                return springContext.getBean(controllerName);
123                        }
124                        catch (NoSuchBeanDefinitionException nexc2) {
125                        }
126                }
127        }
128        
129        return null;            
130    }
131    
132    
133    
134    private static final String SPRINGMVC_BINDING_ATTR = "__SPRINGMVC_LOCAL_BINDING__";
135    
136    @Override
137    public Object[] beforeMethodSearch(Object instance, String methodName, Object[] args) {
138        if (instance instanceof HandlerAdapter) {
139                boolean grails = getSpringContext().getClass().getName().indexOf("Grails") > 0;
140                
141                String componentName = (String)args[0];
142                String componentClassName = (String)args[1];
143            Class<?> componentClass = null;
144            try {
145                if (componentClassName != null)
146                        componentClass = TypeUtil.forName(componentClassName);
147            }
148            catch (ClassNotFoundException e) {
149                throw new ServiceException("Component class not found " + componentClassName, e);
150            }
151                Object component = findComponent(componentName, componentClass);
152                Set<Class<?>> componentClasses = findComponentClasses(componentName, componentClass);
153                Object handler = component;
154                if (grails && componentName.endsWith("Controller")) {
155                        // Special handling for Grails controllers
156                        handler = springContext.getBean("mainSimpleController");
157                }
158                HttpGraniteContext context = (HttpGraniteContext)GraniteContext.getCurrentInstance();
159                @SuppressWarnings("unchecked")
160                Map<String, Object> requestMap = (args[3] != null && args[3] instanceof Object[] && ((Object[])args[3]).length >= 1 && ((Object[])args[3]).length <= 2 && ((Object[])args[3])[0] instanceof Map) 
161                        ? (Map<String, Object>)((Object[])args[3])[0] 
162                        : null;
163                boolean localBinding = false;
164                if (args[3] != null && args[3] instanceof Object[] && ((Object[])args[3]).length == 2 
165                                && ((Object[])args[3])[0] instanceof Map<?, ?> && ((Object[])args[3])[1] instanceof Boolean)
166                        localBinding = (Boolean)((Object[])args[3])[1];
167                context.getRequestMap().put(SPRINGMVC_BINDING_ATTR, localBinding);
168                
169                Map<String, Object> valueMap = null;
170                if (args[4] instanceof InvocationCall) {
171                        valueMap = new HashMap<String, Object>();
172                        for (ContextUpdate u : ((InvocationCall)args[4]).getUpdates())
173                                valueMap.put(u.getComponentName() + (u.getExpression() != null ? "." + u.getExpression() : ""), u.getValue());
174                }
175                
176                if (grails) {
177                                // Special handling for Grails controllers
178                        try {
179                                for (Class<?> cClass : componentClasses) {
180                                        if (cClass.isInterface())
181                                                continue;
182                                        Method m = cClass.getDeclaredMethod("getProperty", String.class);
183                                        @SuppressWarnings("unchecked")
184                                        Map<String, Object> map = (Map<String, Object>)m.invoke(component, "params");
185                                        if (requestMap != null)
186                                                map.putAll(requestMap);
187                                        if (valueMap != null)
188                                                map.putAll(valueMap);
189                                }
190                        }
191                        catch (Exception e) {
192                                // Ignore, probably not a Grails controller
193                        }
194                }
195                ControllerRequestWrapper rw = new ControllerRequestWrapper(grails, context.getRequest(), componentName, (String)args[2], requestMap, valueMap);
196                return new Object[] { "handle", new Object[] { rw, context.getResponse(), handler }};
197        }
198        
199        return super.beforeMethodSearch(instance, methodName, args);
200    }
201    
202    
203    @Override
204    public void prepareCall(ServiceInvocationContext context, IInvocationCall c, String componentName, Class<?> componentClass) {
205        super.prepareCall(context, c, componentName, componentClass);
206                
207        if (componentName == null)
208                return;
209        
210                Object component = findComponent(componentName, componentClass);
211                
212                if (context.getBean() instanceof HandlerAdapter) {
213                        // In case of Spring controllers, call interceptors
214                ApplicationContext webContext = getSpringContext();
215                String[] interceptorNames = webContext.getBeanNamesForType(HandlerInterceptor.class);
216                String[] webRequestInterceptors = webContext.getBeanNamesForType(WebRequestInterceptor.class);
217                HandlerInterceptor[] interceptors = new HandlerInterceptor[interceptorNames.length+webRequestInterceptors.length];
218        
219                int j = 0;
220                for (int i = 0; i < webRequestInterceptors.length; i++)
221                    interceptors[j++] = new WebRequestHandlerInterceptorAdapter((WebRequestInterceptor)webContext.getBean(webRequestInterceptors[i]));
222                for (int i = 0; i < interceptorNames.length; i++)
223                    interceptors[j++] = (HandlerInterceptor)webContext.getBean(interceptorNames[i]);
224                
225                        HttpGraniteContext graniteContext = (HttpGraniteContext)GraniteContext.getCurrentInstance();
226                        
227                        graniteContext.getRequestMap().put(HandlerInterceptor.class.getName(), interceptors);
228                        
229                        try {
230                        for (int i = 0; i < interceptors.length; i++) {
231                            HandlerInterceptor interceptor = interceptors[i];
232                            interceptor.preHandle((HttpServletRequest)context.getParameters()[0], graniteContext.getResponse(), component);
233                        }
234                        }
235                        catch (Exception e) {
236                                throw new ServiceException(e.getMessage(), e);
237                        }
238                }
239    }
240    
241    
242    @Override
243    @SuppressWarnings("unchecked")
244    public IInvocationResult postCall(ServiceInvocationContext context, Object result, String componentName, Class<?> componentClass) {
245        List<ContextUpdate> results = null;
246        
247        Object component = null;
248        if (componentName != null && context.getBean() instanceof HandlerAdapter) {
249                        component = findComponent(componentName, componentClass);
250                        
251                        HttpGraniteContext graniteContext = (HttpGraniteContext)GraniteContext.getCurrentInstance();
252                        
253                Map<String, Object> modelMap = null;
254                if (result instanceof ModelAndView) {
255                        ModelAndView modelAndView = (ModelAndView)result;
256                        modelMap = modelAndView.getModel();
257                        result = modelAndView.getViewName();
258                        
259                        if (context.getBean() instanceof HandlerAdapter) {
260                                try {
261                                        HandlerInterceptor[] interceptors = (HandlerInterceptor[])graniteContext.getRequestMap().get(HandlerInterceptor.class.getName());
262                                        
263                                        if (interceptors != null) {
264                                        for (int i = interceptors.length-1; i >= 0; i--) {
265                                            HandlerInterceptor interceptor = interceptors[i];
266                                            interceptor.postHandle((HttpServletRequest)context.getParameters()[0], graniteContext.getResponse(), component, modelAndView);
267                                        }
268        
269                                            triggerAfterCompletion(component, interceptors.length-1, interceptors, graniteContext.getRequest(), graniteContext.getResponse(), null);
270                                        }
271                                }
272                                catch (Exception e) {
273                                        throw new ServiceException(e.getMessage(), e);
274                                }
275                        }
276                }
277                
278                if (modelMap != null) {
279                        Boolean localBinding = (Boolean)graniteContext.getRequestMap().get(SPRINGMVC_BINDING_ATTR);
280                        
281                        results = new ArrayList<ContextUpdate>();
282                        for (Map.Entry<String, Object> me : modelMap.entrySet()) {
283                                        if (me.getKey().toString().startsWith("org.springframework.validation.")
284                                                        || (me.getValue() != null && (
285                                                                        me.getValue().getClass().getName().startsWith("groovy.lang.ExpandoMetaClass")
286                                                                        || me.getValue().getClass().getName().indexOf("$_closure") > 0
287                                                                        || me.getValue() instanceof Class)))                                                                    
288                                                continue;
289                                        String variableName = me.getKey().toString();
290                                        if (Boolean.TRUE.equals(localBinding))
291                                                results.add(new ContextUpdate(componentName, variableName, me.getValue(), 3, false));
292                                        else
293                                                results.add(new ContextUpdate(variableName, null, me.getValue(), 3, false));
294                        }
295                }
296                
297                        boolean grails = getSpringContext().getClass().getName().indexOf("Grails") > 0;
298                        if (grails) {
299                                // Special handling for Grails controllers: get flash content
300                                try {
301                                Set<Class<?>> componentClasses = findComponentClasses(componentName, componentClass);
302                                for (Class<?> cClass : componentClasses) {
303                                        if (cClass.isInterface())
304                                                continue;
305                                                Method m = cClass.getDeclaredMethod("getProperty", String.class);
306                                                Map<String, Object> map = (Map<String, Object>)m.invoke(component, "flash");
307                                                if (results == null)
308                                                        results = new ArrayList<ContextUpdate>();
309                                                for (Map.Entry<String, Object> me : map.entrySet()) {
310                                                        Object value = me.getValue();
311                                                        if (value != null && value.getClass().getName().startsWith("org.codehaus.groovy.runtime.GString"))
312                                                                value = value.toString();
313                                                        results.add(new ContextUpdate("flash", me.getKey(), value, 3, false));
314                                                }
315                                }
316                                }
317                                catch (Exception e) {
318                                        throw new ServiceException("Flash scope retrieval failed", e);
319                                }
320                        }
321        }
322                
323        DataContext dataContext = DataContext.get();
324                Object[][] updates = dataContext != null ? dataContext.getUpdates() : null;
325                
326        InvocationResult ires = new InvocationResult(result, results);
327        if (component == null)
328                component = context.getBean();
329        if (isBeanAnnotationPresent(component, BypassTideMerge.class))
330                ires.setMerge(false);
331        else if (!(context.getParameters().length > 0 && context.getParameters()[0] instanceof ControllerRequestWrapper)) {
332                if (isBeanMethodAnnotationPresent(component, context.getMethod().getName(), context.getMethod().getParameterTypes(), BypassTideMerge.class))
333                        ires.setMerge(false);
334        }
335        
336        ires.setUpdates(updates);
337        
338        return ires;
339    }
340
341    @Override
342    public void postCallFault(ServiceInvocationContext context, Throwable t, String componentName, Class<?> componentClass) {
343        if (componentName != null && context.getBean() instanceof HandlerAdapter) {
344                        HttpGraniteContext graniteContext = (HttpGraniteContext)GraniteContext.getCurrentInstance();
345                        
346                        Object component = findComponent(componentName, componentClass);
347                        
348                        HandlerInterceptor[] interceptors = (HandlerInterceptor[])graniteContext.getRequestMap().get(HandlerInterceptor.class.getName());
349        
350                triggerAfterCompletion(component, interceptors.length-1, interceptors, 
351                                graniteContext.getRequest(), graniteContext.getResponse(), 
352                                t instanceof Exception ? (Exception)t : null);
353        }
354                
355        super.postCallFault(context, t, componentName, componentClass);
356    }
357
358    
359    private void triggerAfterCompletion(Object component, int interceptorIndex, HandlerInterceptor[] interceptors, HttpServletRequest request, HttpServletResponse response, Exception ex) {
360                for (int i = interceptorIndex; i >= 0; i--) {
361                        HandlerInterceptor interceptor = interceptors[i];
362                        try {
363                                interceptor.afterCompletion(request, response, component, ex);
364                        }
365                        catch (Throwable ex2) {
366                                log.error("HandlerInterceptor.afterCompletion threw exception", ex2);
367                        }
368                }
369    }
370
371    
372    
373    private class ControllerRequestWrapper extends HttpServletRequestWrapper {
374        private String initialComponentName = null;
375        private String componentName = null;
376        private String methodName = null;
377        private Map<String, Object> requestMap = null;
378        private Map<String, Object> valueMap = null;
379        private boolean localBinding;
380        
381                public ControllerRequestWrapper(boolean grails, HttpServletRequest request, String componentName, String methodName, Map<String, Object> requestMap, Map<String, Object> valueMap) {
382                        super(request);
383                        this.initialComponentName = componentName;
384                        this.componentName = componentName.substring(0, componentName.length()-"Controller".length());
385                        if (this.componentName.indexOf(".") > 0)
386                                this.componentName = this.componentName.substring(this.componentName.lastIndexOf(".")+1);
387                        if (grails)
388                                this.componentName = this.componentName.substring(0, 1).toLowerCase() + this.componentName.substring(1);
389                        this.methodName = methodName;
390                        this.requestMap = requestMap;
391                        this.valueMap = valueMap;
392                        this.localBinding = Boolean.TRUE.equals(request.getAttribute(SPRINGMVC_BINDING_ATTR));
393                }
394        
395                @Override
396                public String getRequestURI() {
397                        return getContextPath() + "/" + componentName + "/" + methodName;
398                }
399                
400                @Override
401                public String getServletPath() {
402                        return "/" + componentName + "/" + methodName;
403                }
404                
405                public Object getRequestValue(String key) {
406                        return requestMap != null ? requestMap.get(key) : null;
407                }
408                
409                public Object getBindValue(String key) {
410                        if (valueMap == null)
411                                return null;
412                        
413                        return localBinding && valueMap.containsKey(initialComponentName + "." + key)
414                                        ? valueMap.get(initialComponentName + "." + key)
415                                        : valueMap.get(key);
416                }
417                
418                @Override
419                public String getParameter(String name) {
420                        return requestMap != null && requestMap.containsKey(name) ? REQUEST_VALUE : null;
421                }
422                
423                @Override
424                public String[] getParameterValues(String name) {
425                        return requestMap != null && requestMap.containsKey(name) ? new String[] { REQUEST_VALUE } : null;
426                }
427
428                @Override
429                @SuppressWarnings({ "unchecked", "rawtypes" })
430                public Map getParameterMap() {
431                        Map<String, Object> pmap = new HashMap<String, Object>();
432                        if (requestMap != null) {
433                                for (String name : requestMap.keySet())
434                                        pmap.put(name, REQUEST_VALUE);
435                        }
436                        return pmap; 
437                }
438
439                @Override
440                @SuppressWarnings({ "unchecked", "rawtypes" })
441                public Enumeration getParameterNames() {
442                        Hashtable ht = new Hashtable();
443                        if (requestMap != null)
444                                ht.putAll(requestMap);
445                        return ht.keys();
446                }
447    }
448    
449    
450    private class ControllerRequestDataBinder extends ServletRequestDataBinder {
451        
452        private ControllerRequestWrapper wrapper = null;
453        private Object target = null;
454
455                public ControllerRequestDataBinder(ServletRequest request, Object target, String objectName) {
456                        super(target, objectName);
457                        this.wrapper = (ControllerRequestWrapper)request;
458                        this.target = target;
459                }
460                
461                private Object getBindValue(boolean request, Class<?> requiredType) {
462                        GraniteContext context = GraniteContext.getCurrentInstance();
463                        ClassGetter classGetter = context.getGraniteConfig().getClassGetter();
464                        Object value = request ? wrapper.getRequestValue(getObjectName()) : wrapper.getBindValue(getObjectName());
465                        if (requiredType != null) {
466                                Converter converter = context.getGraniteConfig().getConverters().getConverter(value, requiredType);
467                                if (converter != null)
468                                        value = converter.convert(value, requiredType);
469                        }
470                        if (value != null && !request)
471                                return SpringMVCServiceContext.this.mergeExternal(classGetter, value, null, null, null);
472                        return value;
473                }
474
475                @Override
476                public void bind(ServletRequest request) {
477                        Object value = getBindValue(false, null);
478                        if (value != null)
479                                target = value;
480                }
481                
482                @Override
483                public Object getTarget() {
484                        return target;                  
485                }
486                
487                @SuppressWarnings({ "rawtypes", "unchecked" })
488                @Override
489                public Object convertIfNecessary(Object value, Class requiredType, MethodParameter methodParam) throws TypeMismatchException {
490                        if (target == null && value == REQUEST_VALUE || (value instanceof String[] && ((String[])value)[0] == REQUEST_VALUE))
491                                return getBindValue(true, requiredType);
492                        return super.convertIfNecessary(value, requiredType, methodParam);
493                }
494    }
495}