001    /*
002      GRANITE DATA SERVICES
003      Copyright (C) 2011 GRANITE DATA SERVICES S.A.S.
004    
005      This file is part of Granite Data Services.
006    
007      Granite Data Services is free software; you can redistribute it and/or modify
008      it under the terms of the GNU Library General Public License as published by
009      the Free Software Foundation; either version 2 of the License, or (at your
010      option) any later version.
011    
012      Granite Data Services is distributed in the hope that it will be useful, but
013      WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
014      FITNESS FOR A PARTICULAR PURPOSE. See the GNU Library General Public License
015      for more details.
016    
017      You should have received a copy of the GNU Library General Public License
018      along with this library; if not, see <http://www.gnu.org/licenses/>.
019    */
020    
021    package org.granite.tide.spring;
022    
023    import java.lang.reflect.Method;
024    import java.util.ArrayList;
025    import java.util.Enumeration;
026    import java.util.HashMap;
027    import java.util.Hashtable;
028    import java.util.List;
029    import java.util.Map;
030    import java.util.Set;
031    
032    import javax.servlet.ServletRequest;
033    import javax.servlet.http.HttpServletRequest;
034    import javax.servlet.http.HttpServletRequestWrapper;
035    import javax.servlet.http.HttpServletResponse;
036    
037    import org.granite.context.GraniteContext;
038    import org.granite.logging.Logger;
039    import org.granite.messaging.amf.io.convert.Converter;
040    import org.granite.messaging.amf.io.util.ClassGetter;
041    import org.granite.messaging.service.ServiceException;
042    import org.granite.messaging.service.ServiceInvocationContext;
043    import org.granite.messaging.webapp.HttpGraniteContext;
044    import org.granite.tide.IInvocationCall;
045    import org.granite.tide.IInvocationResult;
046    import org.granite.tide.annotations.BypassTideMerge;
047    import org.granite.tide.data.DataContext;
048    import org.granite.tide.invocation.ContextUpdate;
049    import org.granite.tide.invocation.InvocationCall;
050    import org.granite.tide.invocation.InvocationResult;
051    import org.granite.util.TypeUtil;
052    import org.springframework.beans.TypeMismatchException;
053    import org.springframework.context.ApplicationContext;
054    import org.springframework.core.MethodParameter;
055    import org.springframework.web.bind.ServletRequestDataBinder;
056    import org.springframework.web.context.request.WebRequestInterceptor;
057    import org.springframework.web.servlet.HandlerAdapter;
058    import org.springframework.web.servlet.HandlerInterceptor;
059    import org.springframework.web.servlet.ModelAndView;
060    import org.springframework.web.servlet.handler.WebRequestHandlerInterceptorAdapter;
061    import org.springframework.web.servlet.mvc.Controller;
062    import org.springframework.web.servlet.mvc.SimpleControllerHandlerAdapter;
063    import org.springframework.web.servlet.mvc.annotation.AnnotationMethodHandlerAdapter;
064    
065    
066    /**
067     * @author William DRAI
068     */
069    public class SpringMVCServiceContext extends SpringServiceContext {
070    
071        private static final long serialVersionUID = 1L;
072        
073        private static final String REQUEST_VALUE = "__REQUEST_VALUE__";
074        
075        private static final Logger log = Logger.getLogger(SpringMVCServiceContext.class);
076    
077        
078        public SpringMVCServiceContext() throws ServiceException {
079            super();
080        }
081        
082        public SpringMVCServiceContext(ApplicationContext springContext) throws ServiceException {
083            super(springContext);
084        }
085        
086        
087        @Override
088        public Object adjustInvokee(Object instance, String componentName, Set<Class<?>> componentClasses) {
089            for (Class<?> componentClass : componentClasses) {
090                    if (componentClass.isAnnotationPresent(org.springframework.stereotype.Controller.class)) {
091                            return new AnnotationMethodHandlerAdapter() {
092                                    @Override
093                                    protected ServletRequestDataBinder createBinder(HttpServletRequest request, Object target, String objectName) throws Exception {
094                                            return new ControllerRequestDataBinder(request, target, objectName);
095                                    }
096                            };
097                    }
098            }
099            if (Controller.class.isInstance(instance) || (componentName != null && componentName.endsWith("Controller")))
100                    return new SimpleControllerHandlerAdapter();
101            
102            return instance;
103        }
104        
105        
106        private static final String SPRINGMVC_BINDING_ATTR = "__SPRINGMVC_LOCAL_BINDING__";
107        
108        @Override
109        public Object[] beforeMethodSearch(Object instance, String methodName, Object[] args) {
110            if (instance instanceof HandlerAdapter) {
111                    boolean grails = getSpringContext().getClass().getName().indexOf("Grails") > 0;
112                    
113                    String componentName = (String)args[0];
114                    String componentClassName = (String)args[1];
115                Class<?> componentClass = null;
116                try {
117                    if (componentClassName != null)
118                            componentClass = TypeUtil.forName(componentClassName);
119                }
120                catch (ClassNotFoundException e) {
121                    throw new ServiceException("Component class not found " + componentClassName, e);
122                }
123                    Object component = findComponent(componentName, componentClass);
124                    Set<Class<?>> componentClasses = findComponentClasses(componentName, componentClass);
125                    Object handler = component;
126                    if (grails && componentName.endsWith("Controller")) {
127                            // Special handling for Grails controllers
128                            handler = springContext.getBean("mainSimpleController");
129                    }
130                    HttpGraniteContext context = (HttpGraniteContext)GraniteContext.getCurrentInstance();
131                    @SuppressWarnings("unchecked")
132                    Map<String, Object> requestMap = (args[3] != null && args[3] instanceof Object[] && ((Object[])args[3]).length >= 1 && ((Object[])args[3]).length <= 2 && ((Object[])args[3])[0] instanceof Map) 
133                            ? (Map<String, Object>)((Object[])args[3])[0] 
134                            : null;
135                    boolean localBinding = false;
136                    if (args[3] != null && args[3] instanceof Object[] && ((Object[])args[3]).length == 2 
137                                    && ((Object[])args[3])[0] instanceof Map<?, ?> && ((Object[])args[3])[1] instanceof Boolean)
138                            localBinding = (Boolean)((Object[])args[3])[1];
139                    context.getRequestMap().put(SPRINGMVC_BINDING_ATTR, localBinding);
140                    
141                    Map<String, Object> valueMap = null;
142                    if (args[4] instanceof InvocationCall) {
143                            valueMap = new HashMap<String, Object>();
144                            for (ContextUpdate u : ((InvocationCall)args[4]).getUpdates())
145                                    valueMap.put(u.getComponentName() + (u.getExpression() != null ? "." + u.getExpression() : ""), u.getValue());
146                    }
147                    
148                    if (grails) {
149                                    // Special handling for Grails controllers
150                            try {
151                                    for (Class<?> cClass : componentClasses) {
152                                            if (cClass.isInterface())
153                                                    continue;
154                                            Method m = cClass.getDeclaredMethod("getProperty", String.class);
155                                            @SuppressWarnings("unchecked")
156                                            Map<String, Object> map = (Map<String, Object>)m.invoke(component, "params");
157                                            if (requestMap != null)
158                                                    map.putAll(requestMap);
159                                            if (valueMap != null)
160                                                    map.putAll(valueMap);
161                                    }
162                            }
163                            catch (Exception e) {
164                                    // Ignore, probably not a Grails controller
165                            }
166                    }
167                    ControllerRequestWrapper rw = new ControllerRequestWrapper(grails, context.getRequest(), componentName, (String)args[2], requestMap, valueMap);
168                    return new Object[] { "handle", new Object[] { rw, context.getResponse(), handler }};
169            }
170            
171            return super.beforeMethodSearch(instance, methodName, args);
172        }
173        
174        
175        @Override
176        public void prepareCall(ServiceInvocationContext context, IInvocationCall c, String componentName, Class<?> componentClass) {
177            super.prepareCall(context, c, componentName, componentClass);
178                    
179            if (componentName == null)
180                    return;
181            
182                    Object component = findComponent(componentName, componentClass);
183                    
184                    if (context.getBean() instanceof HandlerAdapter) {
185                            // In case of Spring controllers, call interceptors
186                    ApplicationContext webContext = getSpringContext();
187                    String[] interceptorNames = webContext.getBeanNamesForType(HandlerInterceptor.class);
188                    String[] webRequestInterceptors = webContext.getBeanNamesForType(WebRequestInterceptor.class);
189                    HandlerInterceptor[] interceptors = new HandlerInterceptor[interceptorNames.length+webRequestInterceptors.length];
190            
191                    int j = 0;
192                    for (int i = 0; i < webRequestInterceptors.length; i++)
193                        interceptors[j++] = new WebRequestHandlerInterceptorAdapter((WebRequestInterceptor)webContext.getBean(webRequestInterceptors[i]));
194                    for (int i = 0; i < interceptorNames.length; i++)
195                        interceptors[j++] = (HandlerInterceptor)webContext.getBean(interceptorNames[i]);
196                    
197                            HttpGraniteContext graniteContext = (HttpGraniteContext)GraniteContext.getCurrentInstance();
198                            
199                            graniteContext.getRequestMap().put(HandlerInterceptor.class.getName(), interceptors);
200                            
201                            try {
202                            for (int i = 0; i < interceptors.length; i++) {
203                                HandlerInterceptor interceptor = interceptors[i];
204                                interceptor.preHandle((HttpServletRequest)context.getParameters()[0], graniteContext.getResponse(), component);
205                            }
206                            }
207                            catch (Exception e) {
208                                    throw new ServiceException(e.getMessage(), e);
209                            }
210                    }
211        }
212        
213    
214        @Override
215        @SuppressWarnings("unchecked")
216        public IInvocationResult postCall(ServiceInvocationContext context, Object result, String componentName, Class<?> componentClass) {
217            List<ContextUpdate> results = null;
218            
219            Object component = null;
220            if (componentName != null && context.getBean() instanceof HandlerAdapter) {
221                            component = findComponent(componentName, componentClass);
222                            
223                            HttpGraniteContext graniteContext = (HttpGraniteContext)GraniteContext.getCurrentInstance();
224                            
225                    Map<String, Object> modelMap = null;
226                    if (result instanceof ModelAndView) {
227                            ModelAndView modelAndView = (ModelAndView)result;
228                            modelMap = modelAndView.getModel();
229                            result = modelAndView.getViewName();
230                            
231                            if (context.getBean() instanceof HandlerAdapter) {
232                                    try {
233                                            HandlerInterceptor[] interceptors = (HandlerInterceptor[])graniteContext.getRequestMap().get(HandlerInterceptor.class.getName());
234                                            
235                                            if (interceptors != null) {
236                                            for (int i = interceptors.length-1; i >= 0; i--) {
237                                                HandlerInterceptor interceptor = interceptors[i];
238                                                interceptor.postHandle((HttpServletRequest)context.getParameters()[0], graniteContext.getResponse(), component, modelAndView);
239                                            }
240            
241                                                triggerAfterCompletion(component, interceptors.length-1, interceptors, graniteContext.getRequest(), graniteContext.getResponse(), null);
242                                            }
243                                    }
244                                    catch (Exception e) {
245                                            throw new ServiceException(e.getMessage(), e);
246                                    }
247                            }
248                    }
249                    
250                    if (modelMap != null) {
251                            Boolean localBinding = (Boolean)graniteContext.getRequestMap().get(SPRINGMVC_BINDING_ATTR);
252                            
253                            results = new ArrayList<ContextUpdate>();
254                            for (Map.Entry<String, Object> me : modelMap.entrySet()) {
255                                            if (me.getKey().toString().startsWith("org.springframework.validation.")
256                                                            || (me.getValue() != null && (
257                                                                            me.getValue().getClass().getName().startsWith("groovy.lang.ExpandoMetaClass")
258                                                                            || me.getValue().getClass().getName().indexOf("$_closure") > 0
259                                                                            || me.getValue() instanceof Class)))                                                                    
260                                                    continue;
261                                            String variableName = me.getKey().toString();
262                                            if (Boolean.TRUE.equals(localBinding))
263                                                    results.add(new ContextUpdate(componentName, variableName, me.getValue(), 3, false));
264                                            else
265                                                    results.add(new ContextUpdate(variableName, null, me.getValue(), 3, false));
266                            }
267                    }
268                    
269                            boolean grails = getSpringContext().getClass().getName().indexOf("Grails") > 0;
270                            if (grails) {
271                                    // Special handling for Grails controllers: get flash content
272                                    try {
273                                    Set<Class<?>> componentClasses = findComponentClasses(componentName, componentClass);
274                                    for (Class<?> cClass : componentClasses) {
275                                            if (cClass.isInterface())
276                                                    continue;
277                                                    Method m = cClass.getDeclaredMethod("getProperty", String.class);
278                                                    Map<String, Object> map = (Map<String, Object>)m.invoke(component, "flash");
279                                                    if (results == null)
280                                                            results = new ArrayList<ContextUpdate>();
281                                                    for (Map.Entry<String, Object> me : map.entrySet()) {
282                                                            Object value = me.getValue();
283                                                            if (value != null && value.getClass().getName().startsWith("org.codehaus.groovy.runtime.GString"))
284                                                                    value = value.toString();
285                                                            results.add(new ContextUpdate("flash", me.getKey(), value, 3, false));
286                                                    }
287                                    }
288                                    }
289                                    catch (Exception e) {
290                                            throw new ServiceException("Flash scope retrieval failed", e);
291                                    }
292                            }
293            }
294                    
295            DataContext dataContext = DataContext.get();
296                    Object[][] updates = dataContext != null ? dataContext.getUpdates() : null;
297                    
298            InvocationResult ires = new InvocationResult(result, results);
299            if (component == null)
300                    component = context.getBean();
301            if (isBeanAnnotationPresent(component, BypassTideMerge.class))
302                    ires.setMerge(false);
303            else if (!(context.getParameters().length > 0 && context.getParameters()[0] instanceof ControllerRequestWrapper)) {
304                    if (isBeanMethodAnnotationPresent(component, context.getMethod().getName(), context.getMethod().getParameterTypes(), BypassTideMerge.class))
305                            ires.setMerge(false);
306            }
307            
308            ires.setUpdates(updates);
309            
310            return ires;
311        }
312    
313        @Override
314        public void postCallFault(ServiceInvocationContext context, Throwable t, String componentName, Class<?> componentClass) {
315            if (componentName != null && context.getBean() instanceof HandlerAdapter) {
316                            HttpGraniteContext graniteContext = (HttpGraniteContext)GraniteContext.getCurrentInstance();
317                            
318                            Object component = findComponent(componentName, componentClass);
319                            
320                            HandlerInterceptor[] interceptors = (HandlerInterceptor[])graniteContext.getRequestMap().get(HandlerInterceptor.class.getName());
321            
322                    triggerAfterCompletion(component, interceptors.length-1, interceptors, 
323                                    graniteContext.getRequest(), graniteContext.getResponse(), 
324                                    t instanceof Exception ? (Exception)t : null);
325            }
326                    
327            super.postCallFault(context, t, componentName, componentClass);
328        }
329    
330        
331        private void triggerAfterCompletion(Object component, int interceptorIndex, HandlerInterceptor[] interceptors, HttpServletRequest request, HttpServletResponse response, Exception ex) {
332                    for (int i = interceptorIndex; i >= 0; i--) {
333                            HandlerInterceptor interceptor = interceptors[i];
334                            try {
335                                    interceptor.afterCompletion(request, response, component, ex);
336                            }
337                            catch (Throwable ex2) {
338                                    log.error("HandlerInterceptor.afterCompletion threw exception", ex2);
339                            }
340                    }
341        }
342    
343        
344        
345        private class ControllerRequestWrapper extends HttpServletRequestWrapper {
346            private String componentName = null;
347            private String methodName = null;
348            private Map<String, Object> requestMap = null;
349            private Map<String, Object> valueMap = null;
350            
351                    public ControllerRequestWrapper(boolean grails, HttpServletRequest request, String componentName, String methodName, Map<String, Object> requestMap, Map<String, Object> valueMap) {
352                            super(request);
353                            this.componentName = componentName.substring(0, componentName.length()-"Controller".length());
354                            if (this.componentName.indexOf(".") > 0)
355                                    this.componentName = this.componentName.substring(this.componentName.lastIndexOf(".")+1);
356                            if (grails)
357                                    this.componentName = this.componentName.substring(0, 1).toLowerCase() + this.componentName.substring(1);
358                            this.methodName = methodName;
359                            this.requestMap = requestMap;
360                            this.valueMap = valueMap;
361                    }
362            
363                    @Override
364                    public String getRequestURI() {
365                            return getContextPath() + "/" + componentName + "/" + methodName;
366                    }
367                    
368                    @Override
369                    public String getServletPath() {
370                            return "/" + componentName + "/" + methodName;
371                    }
372                    
373                    public Object getRequestValue(String key) {
374                            return requestMap != null ? requestMap.get(key) : null;
375                    }
376                    
377                    public Object getBindValue(String key) {
378                            return valueMap != null ? valueMap.get(key) : null;
379                    }
380                    
381                    @Override
382                    public String getParameter(String name) {
383                            return requestMap != null && requestMap.containsKey(name) ? REQUEST_VALUE : null;
384                    }
385                    
386                    @Override
387                    public String[] getParameterValues(String name) {
388                            return requestMap != null && requestMap.containsKey(name) ? new String[] { REQUEST_VALUE } : null;
389                    }
390    
391                    @Override
392                    @SuppressWarnings({ "unchecked", "rawtypes" })
393                    public Map getParameterMap() {
394                            Map<String, Object> pmap = new HashMap<String, Object>();
395                            if (requestMap != null) {
396                                    for (String name : requestMap.keySet())
397                                            pmap.put(name, REQUEST_VALUE);
398                            }
399                            return pmap; 
400                    }
401    
402                    @Override
403                    @SuppressWarnings({ "unchecked", "rawtypes" })
404                    public Enumeration getParameterNames() {
405                            Hashtable ht = new Hashtable();
406                            if (requestMap != null)
407                                    ht.putAll(requestMap);
408                            return ht.keys();
409                    }
410        }
411        
412        
413        private class ControllerRequestDataBinder extends ServletRequestDataBinder {
414            
415            private ControllerRequestWrapper wrapper = null;
416            private Object target = null;
417    
418                    public ControllerRequestDataBinder(ServletRequest request, Object target, String objectName) {
419                            super(target, objectName);
420                            this.wrapper = (ControllerRequestWrapper)request;
421                            this.target = target;
422                    }
423                    
424                    private Object getBindValue(boolean request, Class<?> requiredType) {
425                            GraniteContext context = GraniteContext.getCurrentInstance();
426                            ClassGetter classGetter = context.getGraniteConfig().getClassGetter();
427                            Object value = request ? wrapper.getRequestValue(getObjectName()) : wrapper.getBindValue(getObjectName());
428                            if (requiredType != null) {
429                                    Converter converter = context.getGraniteConfig().getConverters().getConverter(value, requiredType);
430                                    if (converter != null)
431                                            value = converter.convert(value, requiredType);
432                            }
433                            if (value != null && !request)
434                                    return SpringMVCServiceContext.this.mergeExternal(classGetter, value, null, null, null);
435                            return value;
436                    }
437    
438                    @Override
439                    public void bind(ServletRequest request) {
440                            Object value = getBindValue(false, null);
441                            if (value != null)
442                                    target = value;
443                    }
444                    
445                    @Override
446                    public Object getTarget() {
447                            return target;                  
448                    }
449                    
450                    @SuppressWarnings({ "rawtypes", "unchecked" })
451                    @Override
452                    public Object convertIfNecessary(Object value, Class requiredType, MethodParameter methodParam) throws TypeMismatchException {
453                            if (target == null && value == REQUEST_VALUE || (value instanceof String[] && ((String[])value)[0] == REQUEST_VALUE))
454                                    return getBindValue(true, requiredType);
455                            return super.convertIfNecessary(value, requiredType, methodParam);
456                    }
457        }
458    }