001    /*
002      GRANITE DATA SERVICES
003      Copyright (C) 2011 GRANITE DATA SERVICES S.A.S.
004    
005      This file is part of Granite Data Services.
006    
007      Granite Data Services is free software; you can redistribute it and/or modify
008      it under the terms of the GNU Library General Public License as published by
009      the Free Software Foundation; either version 2 of the License, or (at your
010      option) any later version.
011    
012      Granite Data Services is distributed in the hope that it will be useful, but
013      WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
014      FITNESS FOR A PARTICULAR PURPOSE. See the GNU Library General Public License
015      for more details.
016    
017      You should have received a copy of the GNU Library General Public License
018      along with this library; if not, see <http://www.gnu.org/licenses/>.
019    */
020    
021    package org.granite.tide.spring;
022    
023    import java.lang.reflect.Method;
024    import java.util.ArrayList;
025    import java.util.Enumeration;
026    import java.util.HashMap;
027    import java.util.Hashtable;
028    import java.util.List;
029    import java.util.Map;
030    import java.util.Set;
031    
032    import javax.servlet.ServletRequest;
033    import javax.servlet.http.HttpServletRequest;
034    import javax.servlet.http.HttpServletRequestWrapper;
035    import javax.servlet.http.HttpServletResponse;
036    
037    import org.granite.context.GraniteContext;
038    import org.granite.logging.Logger;
039    import org.granite.messaging.amf.io.convert.Converter;
040    import org.granite.messaging.amf.io.util.ClassGetter;
041    import org.granite.messaging.service.ServiceException;
042    import org.granite.messaging.service.ServiceInvocationContext;
043    import org.granite.messaging.webapp.HttpGraniteContext;
044    import org.granite.tide.IInvocationCall;
045    import org.granite.tide.IInvocationResult;
046    import org.granite.tide.TidePersistenceManager;
047    import org.granite.tide.annotations.BypassTideMerge;
048    import org.granite.tide.data.DataContext;
049    import org.granite.tide.invocation.ContextUpdate;
050    import org.granite.tide.invocation.InvocationCall;
051    import org.granite.tide.invocation.InvocationResult;
052    import org.granite.util.ClassUtil;
053    import org.springframework.beans.TypeMismatchException;
054    import org.springframework.context.ApplicationContext;
055    import org.springframework.core.MethodParameter;
056    import org.springframework.web.bind.ServletRequestDataBinder;
057    import org.springframework.web.context.request.WebRequestInterceptor;
058    import org.springframework.web.servlet.HandlerAdapter;
059    import org.springframework.web.servlet.HandlerInterceptor;
060    import org.springframework.web.servlet.ModelAndView;
061    import org.springframework.web.servlet.handler.WebRequestHandlerInterceptorAdapter;
062    import org.springframework.web.servlet.mvc.Controller;
063    import org.springframework.web.servlet.mvc.SimpleControllerHandlerAdapter;
064    import org.springframework.web.servlet.mvc.annotation.AnnotationMethodHandlerAdapter;
065    
066    
067    /**
068     * @author William DRAI
069     */
070    public class SpringMVCServiceContext extends SpringServiceContext {
071    
072        private static final long serialVersionUID = 1L;
073        
074        private static final String REQUEST_VALUE = "__REQUEST_VALUE__";
075        
076        private static final Logger log = Logger.getLogger(SpringMVCServiceContext.class);
077    
078        
079        public SpringMVCServiceContext() throws ServiceException {
080            super();
081        }
082        
083        public SpringMVCServiceContext(ApplicationContext springContext) throws ServiceException {
084            super(springContext);
085        }
086        
087        
088        @Override
089        public Object adjustInvokee(Object instance, String componentName, Set<Class<?>> componentClasses) {
090            for (Class<?> componentClass : componentClasses) {
091                    if (componentClass.isAnnotationPresent(org.springframework.stereotype.Controller.class)) {
092                            return new AnnotationMethodHandlerAdapter() {
093                                    @Override
094                                    protected ServletRequestDataBinder createBinder(HttpServletRequest request, Object target, String objectName) throws Exception {
095                                            return new ControllerRequestDataBinder(request, target, objectName);
096                                    }
097                            };
098                    }
099            }
100            if (Controller.class.isInstance(instance) || (componentName != null && componentName.endsWith("Controller")))
101                    return new SimpleControllerHandlerAdapter();
102            
103            return instance;
104        }
105        
106        
107        private static final String SPRINGMVC_BINDING_ATTR = "__SPRINGMVC_LOCAL_BINDING__";
108        
109        @Override
110        public Object[] beforeMethodSearch(Object instance, String methodName, Object[] args) {
111            if (instance instanceof HandlerAdapter) {
112                    boolean grails = getSpringContext().getClass().getName().indexOf("Grails") > 0;
113                    
114                    String componentName = (String)args[0];
115                    String componentClassName = (String)args[1];
116                Class<?> componentClass = null;
117                try {
118                    if (componentClassName != null)
119                            componentClass = ClassUtil.forName(componentClassName);
120                }
121                catch (ClassNotFoundException e) {
122                    throw new ServiceException("Component class not found " + componentClassName, e);
123                }
124                    Object component = findComponent(componentName, componentClass);
125                    Set<Class<?>> componentClasses = findComponentClasses(componentName, componentClass);
126                    Object handler = component;
127                    if (grails && componentName.endsWith("Controller")) {
128                            // Special handling for Grails controllers
129                            handler = springContext.getBean("mainSimpleController");
130                    }
131                    HttpGraniteContext context = (HttpGraniteContext)GraniteContext.getCurrentInstance();
132                    @SuppressWarnings("unchecked")
133                    Map<String, Object> requestMap = (args[3] != null && args[3] instanceof Object[] && ((Object[])args[3]).length >= 1 && ((Object[])args[3]).length <= 2 && ((Object[])args[3])[0] instanceof Map) 
134                            ? (Map<String, Object>)((Object[])args[3])[0] 
135                            : null;
136                    boolean localBinding = false;
137                    if (args[3] != null && args[3] instanceof Object[] && ((Object[])args[3]).length == 2 
138                                    && ((Object[])args[3])[0] instanceof Map<?, ?> && ((Object[])args[3])[1] instanceof Boolean)
139                            localBinding = (Boolean)((Object[])args[3])[1];
140                    context.getRequestMap().put(SPRINGMVC_BINDING_ATTR, localBinding);
141                    
142                    Map<String, Object> valueMap = null;
143                    if (args[4] instanceof InvocationCall) {
144                            valueMap = new HashMap<String, Object>();
145                            for (ContextUpdate u : ((InvocationCall)args[4]).getUpdates())
146                                    valueMap.put(u.getComponentName() + (u.getExpression() != null ? "." + u.getExpression() : ""), u.getValue());
147                    }
148                    
149                    if (grails) {
150                                    // Special handling for Grails controllers
151                            try {
152                                    for (Class<?> cClass : componentClasses) {
153                                            if (cClass.isInterface())
154                                                    continue;
155                                            Method m = cClass.getDeclaredMethod("getProperty", String.class);
156                                            @SuppressWarnings("unchecked")
157                                            Map<String, Object> map = (Map<String, Object>)m.invoke(component, "params");
158                                            if (requestMap != null)
159                                                    map.putAll(requestMap);
160                                            if (valueMap != null)
161                                                    map.putAll(valueMap);
162                                    }
163                            }
164                            catch (Exception e) {
165                                    // Ignore, probably not a Grails controller
166                            }
167                    }
168                    ControllerRequestWrapper rw = new ControllerRequestWrapper(grails, context.getRequest(), componentName, (String)args[2], requestMap, valueMap);
169                    return new Object[] { "handle", new Object[] { rw, context.getResponse(), handler }};
170            }
171            
172            return super.beforeMethodSearch(instance, methodName, args);
173        }
174        
175        
176        @Override
177        public void prepareCall(ServiceInvocationContext context, IInvocationCall c, String componentName, Class<?> componentClass) {
178            super.prepareCall(context, c, componentName, componentClass);
179                    
180            if (componentName == null)
181                    return;
182            
183                    Object component = findComponent(componentName, componentClass);
184                    
185                    if (context.getBean() instanceof HandlerAdapter) {
186                            // In case of Spring controllers, call interceptors
187                    ApplicationContext webContext = getSpringContext();
188                    String[] interceptorNames = webContext.getBeanNamesForType(HandlerInterceptor.class);
189                    String[] webRequestInterceptors = webContext.getBeanNamesForType(WebRequestInterceptor.class);
190                    HandlerInterceptor[] interceptors = new HandlerInterceptor[interceptorNames.length+webRequestInterceptors.length];
191            
192                    int j = 0;
193                    for (int i = 0; i < webRequestInterceptors.length; i++)
194                        interceptors[j++] = new WebRequestHandlerInterceptorAdapter((WebRequestInterceptor)webContext.getBean(webRequestInterceptors[i]));
195                    for (int i = 0; i < interceptorNames.length; i++)
196                        interceptors[j++] = (HandlerInterceptor)webContext.getBean(interceptorNames[i]);
197                    
198                            HttpGraniteContext graniteContext = (HttpGraniteContext)GraniteContext.getCurrentInstance();
199                            
200                            graniteContext.getRequestMap().put(HandlerInterceptor.class.getName(), interceptors);
201                            
202                            try {
203                            for (int i = 0; i < interceptors.length; i++) {
204                                HandlerInterceptor interceptor = interceptors[i];
205                                interceptor.preHandle((HttpServletRequest)context.getParameters()[0], graniteContext.getResponse(), component);
206                            }
207                            }
208                            catch (Exception e) {
209                                    throw new ServiceException(e.getMessage(), e);
210                            }
211                    }
212        }
213        
214    
215        @Override
216        @SuppressWarnings("unchecked")
217        public IInvocationResult postCall(ServiceInvocationContext context, Object result, String componentName, Class<?> componentClass) {
218            List<ContextUpdate> results = null;
219            
220            Object component = null;
221            if (componentName != null && context.getBean() instanceof HandlerAdapter) {
222                            component = findComponent(componentName, componentClass);
223                            
224                            HttpGraniteContext graniteContext = (HttpGraniteContext)GraniteContext.getCurrentInstance();
225                            
226                    Map<String, Object> modelMap = null;
227                    if (result instanceof ModelAndView) {
228                            ModelAndView modelAndView = (ModelAndView)result;
229                            modelMap = modelAndView.getModel();
230                            result = modelAndView.getViewName();
231                            
232                            if (context.getBean() instanceof HandlerAdapter) {
233                                    try {
234                                            HandlerInterceptor[] interceptors = (HandlerInterceptor[])graniteContext.getRequestMap().get(HandlerInterceptor.class.getName());
235                                            
236                                            if (interceptors != null) {
237                                            for (int i = interceptors.length-1; i >= 0; i--) {
238                                                HandlerInterceptor interceptor = interceptors[i];
239                                                interceptor.postHandle((HttpServletRequest)context.getParameters()[0], graniteContext.getResponse(), component, modelAndView);
240                                            }
241            
242                                                triggerAfterCompletion(component, interceptors.length-1, interceptors, graniteContext.getRequest(), graniteContext.getResponse(), null);
243                                            }
244                                    }
245                                    catch (Exception e) {
246                                            throw new ServiceException(e.getMessage(), e);
247                                    }
248                            }
249                    }
250                    
251                    if (modelMap != null) {
252                            Boolean localBinding = (Boolean)graniteContext.getRequestMap().get(SPRINGMVC_BINDING_ATTR);
253                            
254                            results = new ArrayList<ContextUpdate>();
255                            for (Map.Entry<String, Object> me : modelMap.entrySet()) {
256                                            if (me.getKey().toString().startsWith("org.springframework.validation.")
257                                                            || (me.getValue() != null && (
258                                                                            me.getValue().getClass().getName().startsWith("groovy.lang.ExpandoMetaClass")
259                                                                            || me.getValue().getClass().getName().indexOf("$_closure") > 0
260                                                                            || me.getValue() instanceof Class)))                                                                    
261                                                    continue;
262                                            String variableName = me.getKey().toString();
263                                            if (Boolean.TRUE.equals(localBinding))
264                                                    results.add(new ContextUpdate(componentName, variableName, me.getValue(), 3, false));
265                                            else
266                                                    results.add(new ContextUpdate(variableName, null, me.getValue(), 3, false));
267                            }
268                    }
269                    
270                            boolean grails = getSpringContext().getClass().getName().indexOf("Grails") > 0;
271                            if (grails) {
272                                    // Special handling for Grails controllers: get flash content
273                                    try {
274                                    Set<Class<?>> componentClasses = findComponentClasses(componentName, componentClass);
275                                    for (Class<?> cClass : componentClasses) {
276                                            if (cClass.isInterface())
277                                                    continue;
278                                                    Method m = cClass.getDeclaredMethod("getProperty", String.class);
279                                                    Map<String, Object> map = (Map<String, Object>)m.invoke(component, "flash");
280                                                    if (results == null)
281                                                            results = new ArrayList<ContextUpdate>();
282                                                    for (Map.Entry<String, Object> me : map.entrySet()) {
283                                                            Object value = me.getValue();
284                                                            if (value != null && value.getClass().getName().startsWith("org.codehaus.groovy.runtime.GString"))
285                                                                    value = value.toString();
286                                                            results.add(new ContextUpdate("flash", me.getKey(), value, 3, false));
287                                                    }
288                                    }
289                                    }
290                                    catch (Exception e) {
291                                            throw new ServiceException("Flash scope retrieval failed", e);
292                                    }
293                            }
294            }
295                    
296            DataContext dataContext = DataContext.get();
297                    Object[][] updates = dataContext != null ? dataContext.getUpdates() : null;
298                    
299            InvocationResult ires = new InvocationResult(result, results);
300            if (component == null)
301                    component = context.getBean();
302            if (isBeanAnnotationPresent(component, BypassTideMerge.class))
303                    ires.setMerge(false);
304            else if (!(context.getParameters().length > 0 && context.getParameters()[0] instanceof ControllerRequestWrapper)) {
305                    if (isBeanMethodAnnotationPresent(component, context.getMethod().getName(), context.getMethod().getParameterTypes(), BypassTideMerge.class))
306                            ires.setMerge(false);
307            }
308            
309            ires.setUpdates(updates);
310            
311            return ires;
312        }
313    
314        @Override
315        public void postCallFault(ServiceInvocationContext context, Throwable t, String componentName, Class<?> componentClass) {
316            if (componentName != null && context.getBean() instanceof HandlerAdapter) {
317                            HttpGraniteContext graniteContext = (HttpGraniteContext)GraniteContext.getCurrentInstance();
318                            
319                            Object component = findComponent(componentName, componentClass);
320                            
321                            HandlerInterceptor[] interceptors = (HandlerInterceptor[])graniteContext.getRequestMap().get(HandlerInterceptor.class.getName());
322            
323                    triggerAfterCompletion(component, interceptors.length-1, interceptors, 
324                                    graniteContext.getRequest(), graniteContext.getResponse(), 
325                                    t instanceof Exception ? (Exception)t : null);
326            }
327                    
328            super.postCallFault(context, t, componentName, componentClass);
329        }
330    
331        
332        private void triggerAfterCompletion(Object component, int interceptorIndex, HandlerInterceptor[] interceptors, HttpServletRequest request, HttpServletResponse response, Exception ex) {
333                    for (int i = interceptorIndex; i >= 0; i--) {
334                            HandlerInterceptor interceptor = interceptors[i];
335                            try {
336                                    interceptor.afterCompletion(request, response, component, ex);
337                            }
338                            catch (Throwable ex2) {
339                                    log.error("HandlerInterceptor.afterCompletion threw exception", ex2);
340                            }
341                    }
342        }
343    
344        
345        
346        private class ControllerRequestWrapper extends HttpServletRequestWrapper {
347            private String componentName = null;
348            private String methodName = null;
349            private Map<String, Object> requestMap = null;
350            private Map<String, Object> valueMap = null;
351            
352                    public ControllerRequestWrapper(boolean grails, HttpServletRequest request, String componentName, String methodName, Map<String, Object> requestMap, Map<String, Object> valueMap) {
353                            super(request);
354                            this.componentName = componentName.substring(0, componentName.length()-"Controller".length());
355                            if (this.componentName.indexOf(".") > 0)
356                                    this.componentName = this.componentName.substring(this.componentName.lastIndexOf(".")+1);
357                            if (grails)
358                                    this.componentName = this.componentName.substring(0, 1).toLowerCase() + this.componentName.substring(1);
359                            this.methodName = methodName;
360                            this.requestMap = requestMap;
361                            this.valueMap = valueMap;
362                    }
363            
364                    @Override
365                    public String getRequestURI() {
366                            return getContextPath() + "/" + componentName + "/" + methodName;
367                    }
368                    
369                    @Override
370                    public String getServletPath() {
371                            return "/" + componentName + "/" + methodName;
372                    }
373                    
374                    public Object getRequestValue(String key) {
375                            return requestMap != null ? requestMap.get(key) : null;
376                    }
377                    
378                    public Object getBindValue(String key) {
379                            return valueMap != null ? valueMap.get(key) : null;
380                    }
381                    
382                    @Override
383                    public String getParameter(String name) {
384                            return requestMap != null && requestMap.containsKey(name) ? REQUEST_VALUE : null;
385                    }
386                    
387                    @Override
388                    public String[] getParameterValues(String name) {
389                            return requestMap != null && requestMap.containsKey(name) ? new String[] { REQUEST_VALUE } : null;
390                    }
391    
392                    @Override
393                    @SuppressWarnings({ "unchecked", "rawtypes" })
394                    public Map getParameterMap() {
395                            Map<String, Object> pmap = new HashMap<String, Object>();
396                            if (requestMap != null) {
397                                    for (String name : requestMap.keySet())
398                                            pmap.put(name, REQUEST_VALUE);
399                            }
400                            return pmap; 
401                    }
402    
403                    @Override
404                    @SuppressWarnings({ "unchecked", "rawtypes" })
405                    public Enumeration getParameterNames() {
406                            Hashtable ht = new Hashtable();
407                            if (requestMap != null)
408                                    ht.putAll(requestMap);
409                            return ht.keys();
410                    }
411        }
412        
413        
414        private class ControllerRequestDataBinder extends ServletRequestDataBinder {
415            
416            private ControllerRequestWrapper wrapper = null;
417            private Object target = null;
418    
419                    public ControllerRequestDataBinder(ServletRequest request, Object target, String objectName) {
420                            super(target, objectName);
421                            this.wrapper = (ControllerRequestWrapper)request;
422                            this.target = target;
423                    }
424                    
425                    private Object getBindValue(boolean request, Class<?> requiredType) {
426                            GraniteContext context = GraniteContext.getCurrentInstance();
427                            TidePersistenceManager pm = SpringMVCServiceContext.this.getTidePersistenceManager(true);
428                            ClassGetter classGetter = context.getGraniteConfig().getClassGetter();
429                            Object value = request ? wrapper.getRequestValue(getObjectName()) : wrapper.getBindValue(getObjectName());
430                            if (requiredType != null) {
431                                    Converter converter = context.getGraniteConfig().getConverters().getConverter(value, requiredType);
432                                    if (converter != null)
433                                            value = converter.convert(value, requiredType);
434                            }
435                            if (value != null && !request)
436                                    return SpringMVCServiceContext.this.mergeExternal(pm, classGetter, value, null, null, null);
437                            return value;
438                    }
439    
440                    @Override
441                    public void bind(ServletRequest request) {
442                            Object value = getBindValue(false, null);
443                            if (value != null)
444                                    target = value;
445                    }
446                    
447                    @Override
448                    public Object getTarget() {
449                            return target;                  
450                    }
451                    
452                    @SuppressWarnings({ "rawtypes", "unchecked" })
453                    @Override
454                    public Object convertIfNecessary(Object value, Class requiredType, MethodParameter methodParam) throws TypeMismatchException {
455                            if (target == null && value == REQUEST_VALUE || (value instanceof String[] && ((String[])value)[0] == REQUEST_VALUE))
456                                    return getBindValue(true, requiredType);
457                            return super.convertIfNecessary(value, requiredType, methodParam);
458                    }
459        }
460    }