001/*
002  GRANITE DATA SERVICES
003  Copyright (C) 2011 GRANITE DATA SERVICES S.A.S.
004
005  This file is part of Granite Data Services.
006
007  Granite Data Services is free software; you can redistribute it and/or modify
008  it under the terms of the GNU Library General Public License as published by
009  the Free Software Foundation; either version 2 of the License, or (at your
010  option) any later version.
011
012  Granite Data Services is distributed in the hope that it will be useful, but
013  WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
014  FITNESS FOR A PARTICULAR PURPOSE. See the GNU Library General Public License
015  for more details.
016
017  You should have received a copy of the GNU Library General Public License
018  along with this library; if not, see <http://www.gnu.org/licenses/>.
019*/
020
021package org.granite.tide.spring;
022
023import java.lang.reflect.Method;
024import java.util.ArrayList;
025import java.util.Enumeration;
026import java.util.HashMap;
027import java.util.Hashtable;
028import java.util.List;
029import java.util.Map;
030import java.util.Set;
031
032import javax.servlet.ServletRequest;
033import javax.servlet.http.HttpServletRequest;
034import javax.servlet.http.HttpServletRequestWrapper;
035import javax.servlet.http.HttpServletResponse;
036
037import org.granite.context.GraniteContext;
038import org.granite.logging.Logger;
039import org.granite.messaging.amf.io.convert.Converter;
040import org.granite.messaging.amf.io.util.ClassGetter;
041import org.granite.messaging.service.ServiceException;
042import org.granite.messaging.service.ServiceInvocationContext;
043import org.granite.messaging.webapp.HttpGraniteContext;
044import org.granite.tide.IInvocationCall;
045import org.granite.tide.IInvocationResult;
046import org.granite.tide.annotations.BypassTideMerge;
047import org.granite.tide.data.DataContext;
048import org.granite.tide.invocation.ContextUpdate;
049import org.granite.tide.invocation.InvocationCall;
050import org.granite.tide.invocation.InvocationResult;
051import org.granite.util.TypeUtil;
052import org.springframework.beans.TypeMismatchException;
053import org.springframework.context.ApplicationContext;
054import org.springframework.core.MethodParameter;
055import org.springframework.web.bind.ServletRequestDataBinder;
056import org.springframework.web.context.request.WebRequestInterceptor;
057import org.springframework.web.servlet.HandlerAdapter;
058import org.springframework.web.servlet.HandlerInterceptor;
059import org.springframework.web.servlet.ModelAndView;
060import org.springframework.web.servlet.handler.WebRequestHandlerInterceptorAdapter;
061import org.springframework.web.servlet.mvc.Controller;
062import org.springframework.web.servlet.mvc.SimpleControllerHandlerAdapter;
063import org.springframework.web.servlet.mvc.annotation.AnnotationMethodHandlerAdapter;
064
065
066/**
067 * @author William DRAI
068 */
069public class SpringMVCServiceContext extends SpringServiceContext {
070
071    private static final long serialVersionUID = 1L;
072    
073    private static final String REQUEST_VALUE = "__REQUEST_VALUE__";
074    
075    private static final Logger log = Logger.getLogger(SpringMVCServiceContext.class);
076
077    
078    public SpringMVCServiceContext() throws ServiceException {
079        super();
080    }
081    
082    public SpringMVCServiceContext(ApplicationContext springContext) throws ServiceException {
083        super(springContext);
084    }
085    
086    
087    @Override
088    public Object adjustInvokee(Object instance, String componentName, Set<Class<?>> componentClasses) {
089        for (Class<?> componentClass : componentClasses) {
090                if (componentClass.isAnnotationPresent(org.springframework.stereotype.Controller.class)) {
091                        return new AnnotationMethodHandlerAdapter() {
092                                @Override
093                                protected ServletRequestDataBinder createBinder(HttpServletRequest request, Object target, String objectName) throws Exception {
094                                        return new ControllerRequestDataBinder(request, target, objectName);
095                                }
096                        };
097                }
098        }
099        if (Controller.class.isInstance(instance) || (componentName != null && componentName.endsWith("Controller")))
100                return new SimpleControllerHandlerAdapter();
101        
102        return instance;
103    }
104    
105    
106    private static final String SPRINGMVC_BINDING_ATTR = "__SPRINGMVC_LOCAL_BINDING__";
107    
108    @Override
109    public Object[] beforeMethodSearch(Object instance, String methodName, Object[] args) {
110        if (instance instanceof HandlerAdapter) {
111                boolean grails = getSpringContext().getClass().getName().indexOf("Grails") > 0;
112                
113                String componentName = (String)args[0];
114                String componentClassName = (String)args[1];
115            Class<?> componentClass = null;
116            try {
117                if (componentClassName != null)
118                        componentClass = TypeUtil.forName(componentClassName);
119            }
120            catch (ClassNotFoundException e) {
121                throw new ServiceException("Component class not found " + componentClassName, e);
122            }
123                Object component = findComponent(componentName, componentClass);
124                Set<Class<?>> componentClasses = findComponentClasses(componentName, componentClass);
125                Object handler = component;
126                if (grails && componentName.endsWith("Controller")) {
127                        // Special handling for Grails controllers
128                        handler = springContext.getBean("mainSimpleController");
129                }
130                HttpGraniteContext context = (HttpGraniteContext)GraniteContext.getCurrentInstance();
131                @SuppressWarnings("unchecked")
132                Map<String, Object> requestMap = (args[3] != null && args[3] instanceof Object[] && ((Object[])args[3]).length >= 1 && ((Object[])args[3]).length <= 2 && ((Object[])args[3])[0] instanceof Map) 
133                        ? (Map<String, Object>)((Object[])args[3])[0] 
134                        : null;
135                boolean localBinding = false;
136                if (args[3] != null && args[3] instanceof Object[] && ((Object[])args[3]).length == 2 
137                                && ((Object[])args[3])[0] instanceof Map<?, ?> && ((Object[])args[3])[1] instanceof Boolean)
138                        localBinding = (Boolean)((Object[])args[3])[1];
139                context.getRequestMap().put(SPRINGMVC_BINDING_ATTR, localBinding);
140                
141                Map<String, Object> valueMap = null;
142                if (args[4] instanceof InvocationCall) {
143                        valueMap = new HashMap<String, Object>();
144                        for (ContextUpdate u : ((InvocationCall)args[4]).getUpdates())
145                                valueMap.put(u.getComponentName() + (u.getExpression() != null ? "." + u.getExpression() : ""), u.getValue());
146                }
147                
148                if (grails) {
149                                // Special handling for Grails controllers
150                        try {
151                                for (Class<?> cClass : componentClasses) {
152                                        if (cClass.isInterface())
153                                                continue;
154                                        Method m = cClass.getDeclaredMethod("getProperty", String.class);
155                                        @SuppressWarnings("unchecked")
156                                        Map<String, Object> map = (Map<String, Object>)m.invoke(component, "params");
157                                        if (requestMap != null)
158                                                map.putAll(requestMap);
159                                        if (valueMap != null)
160                                                map.putAll(valueMap);
161                                }
162                        }
163                        catch (Exception e) {
164                                // Ignore, probably not a Grails controller
165                        }
166                }
167                ControllerRequestWrapper rw = new ControllerRequestWrapper(grails, context.getRequest(), componentName, (String)args[2], requestMap, valueMap);
168                return new Object[] { "handle", new Object[] { rw, context.getResponse(), handler }};
169        }
170        
171        return super.beforeMethodSearch(instance, methodName, args);
172    }
173    
174    
175    @Override
176    public void prepareCall(ServiceInvocationContext context, IInvocationCall c, String componentName, Class<?> componentClass) {
177        super.prepareCall(context, c, componentName, componentClass);
178                
179        if (componentName == null)
180                return;
181        
182                Object component = findComponent(componentName, componentClass);
183                
184                if (context.getBean() instanceof HandlerAdapter) {
185                        // In case of Spring controllers, call interceptors
186                ApplicationContext webContext = getSpringContext();
187                String[] interceptorNames = webContext.getBeanNamesForType(HandlerInterceptor.class);
188                String[] webRequestInterceptors = webContext.getBeanNamesForType(WebRequestInterceptor.class);
189                HandlerInterceptor[] interceptors = new HandlerInterceptor[interceptorNames.length+webRequestInterceptors.length];
190        
191                int j = 0;
192                for (int i = 0; i < webRequestInterceptors.length; i++)
193                    interceptors[j++] = new WebRequestHandlerInterceptorAdapter((WebRequestInterceptor)webContext.getBean(webRequestInterceptors[i]));
194                for (int i = 0; i < interceptorNames.length; i++)
195                    interceptors[j++] = (HandlerInterceptor)webContext.getBean(interceptorNames[i]);
196                
197                        HttpGraniteContext graniteContext = (HttpGraniteContext)GraniteContext.getCurrentInstance();
198                        
199                        graniteContext.getRequestMap().put(HandlerInterceptor.class.getName(), interceptors);
200                        
201                        try {
202                        for (int i = 0; i < interceptors.length; i++) {
203                            HandlerInterceptor interceptor = interceptors[i];
204                            interceptor.preHandle((HttpServletRequest)context.getParameters()[0], graniteContext.getResponse(), component);
205                        }
206                        }
207                        catch (Exception e) {
208                                throw new ServiceException(e.getMessage(), e);
209                        }
210                }
211    }
212    
213
214    @Override
215    @SuppressWarnings("unchecked")
216    public IInvocationResult postCall(ServiceInvocationContext context, Object result, String componentName, Class<?> componentClass) {
217        List<ContextUpdate> results = null;
218        
219        Object component = null;
220        if (componentName != null && context.getBean() instanceof HandlerAdapter) {
221                        component = findComponent(componentName, componentClass);
222                        
223                        HttpGraniteContext graniteContext = (HttpGraniteContext)GraniteContext.getCurrentInstance();
224                        
225                Map<String, Object> modelMap = null;
226                if (result instanceof ModelAndView) {
227                        ModelAndView modelAndView = (ModelAndView)result;
228                        modelMap = modelAndView.getModel();
229                        result = modelAndView.getViewName();
230                        
231                        if (context.getBean() instanceof HandlerAdapter) {
232                                try {
233                                        HandlerInterceptor[] interceptors = (HandlerInterceptor[])graniteContext.getRequestMap().get(HandlerInterceptor.class.getName());
234                                        
235                                        if (interceptors != null) {
236                                        for (int i = interceptors.length-1; i >= 0; i--) {
237                                            HandlerInterceptor interceptor = interceptors[i];
238                                            interceptor.postHandle((HttpServletRequest)context.getParameters()[0], graniteContext.getResponse(), component, modelAndView);
239                                        }
240        
241                                            triggerAfterCompletion(component, interceptors.length-1, interceptors, graniteContext.getRequest(), graniteContext.getResponse(), null);
242                                        }
243                                }
244                                catch (Exception e) {
245                                        throw new ServiceException(e.getMessage(), e);
246                                }
247                        }
248                }
249                
250                if (modelMap != null) {
251                        Boolean localBinding = (Boolean)graniteContext.getRequestMap().get(SPRINGMVC_BINDING_ATTR);
252                        
253                        results = new ArrayList<ContextUpdate>();
254                        for (Map.Entry<String, Object> me : modelMap.entrySet()) {
255                                        if (me.getKey().toString().startsWith("org.springframework.validation.")
256                                                        || (me.getValue() != null && (
257                                                                        me.getValue().getClass().getName().startsWith("groovy.lang.ExpandoMetaClass")
258                                                                        || me.getValue().getClass().getName().indexOf("$_closure") > 0
259                                                                        || me.getValue() instanceof Class)))                                                                    
260                                                continue;
261                                        String variableName = me.getKey().toString();
262                                        if (Boolean.TRUE.equals(localBinding))
263                                                results.add(new ContextUpdate(componentName, variableName, me.getValue(), 3, false));
264                                        else
265                                                results.add(new ContextUpdate(variableName, null, me.getValue(), 3, false));
266                        }
267                }
268                
269                        boolean grails = getSpringContext().getClass().getName().indexOf("Grails") > 0;
270                        if (grails) {
271                                // Special handling for Grails controllers: get flash content
272                                try {
273                                Set<Class<?>> componentClasses = findComponentClasses(componentName, componentClass);
274                                for (Class<?> cClass : componentClasses) {
275                                        if (cClass.isInterface())
276                                                continue;
277                                                Method m = cClass.getDeclaredMethod("getProperty", String.class);
278                                                Map<String, Object> map = (Map<String, Object>)m.invoke(component, "flash");
279                                                if (results == null)
280                                                        results = new ArrayList<ContextUpdate>();
281                                                for (Map.Entry<String, Object> me : map.entrySet()) {
282                                                        Object value = me.getValue();
283                                                        if (value != null && value.getClass().getName().startsWith("org.codehaus.groovy.runtime.GString"))
284                                                                value = value.toString();
285                                                        results.add(new ContextUpdate("flash", me.getKey(), value, 3, false));
286                                                }
287                                }
288                                }
289                                catch (Exception e) {
290                                        throw new ServiceException("Flash scope retrieval failed", e);
291                                }
292                        }
293        }
294                
295        DataContext dataContext = DataContext.get();
296                Object[][] updates = dataContext != null ? dataContext.getUpdates() : null;
297                
298        InvocationResult ires = new InvocationResult(result, results);
299        if (component == null)
300                component = context.getBean();
301        if (isBeanAnnotationPresent(component, BypassTideMerge.class))
302                ires.setMerge(false);
303        else if (!(context.getParameters().length > 0 && context.getParameters()[0] instanceof ControllerRequestWrapper)) {
304                if (isBeanMethodAnnotationPresent(component, context.getMethod().getName(), context.getMethod().getParameterTypes(), BypassTideMerge.class))
305                        ires.setMerge(false);
306        }
307        
308        ires.setUpdates(updates);
309        
310        return ires;
311    }
312
313    @Override
314    public void postCallFault(ServiceInvocationContext context, Throwable t, String componentName, Class<?> componentClass) {
315        if (componentName != null && context.getBean() instanceof HandlerAdapter) {
316                        HttpGraniteContext graniteContext = (HttpGraniteContext)GraniteContext.getCurrentInstance();
317                        
318                        Object component = findComponent(componentName, componentClass);
319                        
320                        HandlerInterceptor[] interceptors = (HandlerInterceptor[])graniteContext.getRequestMap().get(HandlerInterceptor.class.getName());
321        
322                triggerAfterCompletion(component, interceptors.length-1, interceptors, 
323                                graniteContext.getRequest(), graniteContext.getResponse(), 
324                                t instanceof Exception ? (Exception)t : null);
325        }
326                
327        super.postCallFault(context, t, componentName, componentClass);
328    }
329
330    
331    private void triggerAfterCompletion(Object component, int interceptorIndex, HandlerInterceptor[] interceptors, HttpServletRequest request, HttpServletResponse response, Exception ex) {
332                for (int i = interceptorIndex; i >= 0; i--) {
333                        HandlerInterceptor interceptor = interceptors[i];
334                        try {
335                                interceptor.afterCompletion(request, response, component, ex);
336                        }
337                        catch (Throwable ex2) {
338                                log.error("HandlerInterceptor.afterCompletion threw exception", ex2);
339                        }
340                }
341    }
342
343    
344    
345    private class ControllerRequestWrapper extends HttpServletRequestWrapper {
346        private String componentName = null;
347        private String methodName = null;
348        private Map<String, Object> requestMap = null;
349        private Map<String, Object> valueMap = null;
350        
351                public ControllerRequestWrapper(boolean grails, HttpServletRequest request, String componentName, String methodName, Map<String, Object> requestMap, Map<String, Object> valueMap) {
352                        super(request);
353                        this.componentName = componentName.substring(0, componentName.length()-"Controller".length());
354                        if (this.componentName.indexOf(".") > 0)
355                                this.componentName = this.componentName.substring(this.componentName.lastIndexOf(".")+1);
356                        if (grails)
357                                this.componentName = this.componentName.substring(0, 1).toLowerCase() + this.componentName.substring(1);
358                        this.methodName = methodName;
359                        this.requestMap = requestMap;
360                        this.valueMap = valueMap;
361                }
362        
363                @Override
364                public String getRequestURI() {
365                        return getContextPath() + "/" + componentName + "/" + methodName;
366                }
367                
368                @Override
369                public String getServletPath() {
370                        return "/" + componentName + "/" + methodName;
371                }
372                
373                public Object getRequestValue(String key) {
374                        return requestMap != null ? requestMap.get(key) : null;
375                }
376                
377                public Object getBindValue(String key) {
378                        return valueMap != null ? valueMap.get(key) : null;
379                }
380                
381                @Override
382                public String getParameter(String name) {
383                        return requestMap != null && requestMap.containsKey(name) ? REQUEST_VALUE : null;
384                }
385                
386                @Override
387                public String[] getParameterValues(String name) {
388                        return requestMap != null && requestMap.containsKey(name) ? new String[] { REQUEST_VALUE } : null;
389                }
390
391                @Override
392                @SuppressWarnings({ "unchecked", "rawtypes" })
393                public Map getParameterMap() {
394                        Map<String, Object> pmap = new HashMap<String, Object>();
395                        if (requestMap != null) {
396                                for (String name : requestMap.keySet())
397                                        pmap.put(name, REQUEST_VALUE);
398                        }
399                        return pmap; 
400                }
401
402                @Override
403                @SuppressWarnings({ "unchecked", "rawtypes" })
404                public Enumeration getParameterNames() {
405                        Hashtable ht = new Hashtable();
406                        if (requestMap != null)
407                                ht.putAll(requestMap);
408                        return ht.keys();
409                }
410    }
411    
412    
413    private class ControllerRequestDataBinder extends ServletRequestDataBinder {
414        
415        private ControllerRequestWrapper wrapper = null;
416        private Object target = null;
417
418                public ControllerRequestDataBinder(ServletRequest request, Object target, String objectName) {
419                        super(target, objectName);
420                        this.wrapper = (ControllerRequestWrapper)request;
421                        this.target = target;
422                }
423                
424                private Object getBindValue(boolean request, Class<?> requiredType) {
425                        GraniteContext context = GraniteContext.getCurrentInstance();
426                        ClassGetter classGetter = context.getGraniteConfig().getClassGetter();
427                        Object value = request ? wrapper.getRequestValue(getObjectName()) : wrapper.getBindValue(getObjectName());
428                        if (requiredType != null) {
429                                Converter converter = context.getGraniteConfig().getConverters().getConverter(value, requiredType);
430                                if (converter != null)
431                                        value = converter.convert(value, requiredType);
432                        }
433                        if (value != null && !request)
434                                return SpringMVCServiceContext.this.mergeExternal(classGetter, value, null, null, null);
435                        return value;
436                }
437
438                @Override
439                public void bind(ServletRequest request) {
440                        Object value = getBindValue(false, null);
441                        if (value != null)
442                                target = value;
443                }
444                
445                @Override
446                public Object getTarget() {
447                        return target;                  
448                }
449                
450                @SuppressWarnings({ "rawtypes", "unchecked" })
451                @Override
452                public Object convertIfNecessary(Object value, Class requiredType, MethodParameter methodParam) throws TypeMismatchException {
453                        if (target == null && value == REQUEST_VALUE || (value instanceof String[] && ((String[])value)[0] == REQUEST_VALUE))
454                                return getBindValue(true, requiredType);
455                        return super.convertIfNecessary(value, requiredType, methodParam);
456                }
457    }
458}