001    /*
002      GRANITE DATA SERVICES
003      Copyright (C) 2011 GRANITE DATA SERVICES S.A.S.
004    
005      This file is part of Granite Data Services.
006    
007      Granite Data Services is free software; you can redistribute it and/or modify
008      it under the terms of the GNU Library General Public License as published by
009      the Free Software Foundation; either version 2 of the License, or (at your
010      option) any later version.
011    
012      Granite Data Services is distributed in the hope that it will be useful, but
013      WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
014      FITNESS FOR A PARTICULAR PURPOSE. See the GNU Library General Public License
015      for more details.
016    
017      You should have received a copy of the GNU Library General Public License
018      along with this library; if not, see <http://www.gnu.org/licenses/>.
019    */
020    
021    package org.granite.tide.spring;
022    
023    import java.lang.reflect.Method;
024    import java.util.ArrayList;
025    import java.util.Enumeration;
026    import java.util.HashMap;
027    import java.util.Hashtable;
028    import java.util.List;
029    import java.util.Map;
030    import java.util.Set;
031    
032    import javax.servlet.ServletRequest;
033    import javax.servlet.http.HttpServletRequest;
034    import javax.servlet.http.HttpServletRequestWrapper;
035    import javax.servlet.http.HttpServletResponse;
036    
037    import org.granite.context.GraniteContext;
038    import org.granite.logging.Logger;
039    import org.granite.messaging.amf.io.convert.Converter;
040    import org.granite.messaging.amf.io.util.ClassGetter;
041    import org.granite.messaging.service.ServiceException;
042    import org.granite.messaging.service.ServiceInvocationContext;
043    import org.granite.messaging.webapp.HttpGraniteContext;
044    import org.granite.tide.IInvocationCall;
045    import org.granite.tide.IInvocationResult;
046    import org.granite.tide.TidePersistenceManager;
047    import org.granite.tide.annotations.BypassTideMerge;
048    import org.granite.tide.data.DataContext;
049    import org.granite.tide.invocation.ContextUpdate;
050    import org.granite.tide.invocation.InvocationCall;
051    import org.granite.tide.invocation.InvocationResult;
052    import org.granite.util.ClassUtil;
053    import org.springframework.beans.TypeMismatchException;
054    import org.springframework.context.ApplicationContext;
055    import org.springframework.core.MethodParameter;
056    import org.springframework.web.bind.ServletRequestDataBinder;
057    import org.springframework.web.context.request.WebRequestInterceptor;
058    import org.springframework.web.servlet.HandlerAdapter;
059    import org.springframework.web.servlet.HandlerInterceptor;
060    import org.springframework.web.servlet.ModelAndView;
061    import org.springframework.web.servlet.handler.WebRequestHandlerInterceptorAdapter;
062    import org.springframework.web.servlet.mvc.Controller;
063    import org.springframework.web.servlet.mvc.SimpleControllerHandlerAdapter;
064    import org.springframework.web.servlet.mvc.annotation.AnnotationMethodHandlerAdapter;
065    
066    
067    /**
068     * @author William DRAI
069     */
070    public class SpringMVCServiceContext extends SpringServiceContext {
071    
072        private static final long serialVersionUID = 1L;
073        
074        private static final String REQUEST_VALUE = "__REQUEST_VALUE__";
075        
076        private static final Logger log = Logger.getLogger(SpringMVCServiceContext.class);
077    
078        
079        @Override
080        public Object adjustInvokee(Object instance, String componentName, Set<Class<?>> componentClasses) {
081            for (Class<?> componentClass : componentClasses) {
082                    if (componentClass.isAnnotationPresent(org.springframework.stereotype.Controller.class)) {
083                            return new AnnotationMethodHandlerAdapter() {
084                                    @Override
085                                    protected ServletRequestDataBinder createBinder(HttpServletRequest request, Object target, String objectName) throws Exception {
086                                            return new ControllerRequestDataBinder(request, target, objectName);
087                                    }
088                            };
089                    }
090            }
091            if (Controller.class.isInstance(instance) || componentName.endsWith("Controller"))
092                    return new SimpleControllerHandlerAdapter();
093            
094            return instance;
095        }
096        
097        
098        private static final String SPRINGMVC_BINDING_ATTR = "__SPRINGMVC_LOCAL_BINDING__";
099        
100        @Override
101        public Object[] beforeMethodSearch(Object instance, String methodName, Object[] args) {
102            if (instance instanceof HandlerAdapter) {
103                    boolean grails = getSpringContext().getClass().getName().indexOf("Grails") > 0;
104                    
105                    String componentName = (String)args[0];
106                    String componentClassName = (String)args[1];
107                Class<?> componentClass = null;
108                try {
109                    if (componentClassName != null)
110                            componentClass = ClassUtil.forName(componentClassName);
111                }
112                catch (ClassNotFoundException e) {
113                    throw new ServiceException("Component class not found " + componentClassName, e);
114                }
115                    Object component = findComponent(componentName, componentClass);
116                    Set<Class<?>> componentClasses = findComponentClasses(componentName, componentClass);
117                    Object handler = component;
118                    if (grails && componentName.endsWith("Controller")) {
119                            // Special handling for Grails controllers
120                            handler = springContext.getBean("mainSimpleController");
121                    }
122                    HttpGraniteContext context = (HttpGraniteContext)GraniteContext.getCurrentInstance();
123                    @SuppressWarnings("unchecked")
124                    Map<String, Object> requestMap = (args[3] != null && args[3] instanceof Object[] && ((Object[])args[3]).length >= 1 && ((Object[])args[3]).length <= 2 && ((Object[])args[3])[0] instanceof Map) 
125                            ? (Map<String, Object>)((Object[])args[3])[0] 
126                            : null;
127                    boolean localBinding = false;
128                    if (args[3] != null && args[3] instanceof Object[] && ((Object[])args[3]).length == 2 
129                                    && ((Object[])args[3])[0] instanceof Map<?, ?> && ((Object[])args[3])[1] instanceof Boolean)
130                            localBinding = (Boolean)((Object[])args[3])[1];
131                    context.getRequestMap().put(SPRINGMVC_BINDING_ATTR, localBinding);
132                    
133                    Map<String, Object> valueMap = null;
134                    if (args[4] instanceof InvocationCall) {
135                            valueMap = new HashMap<String, Object>();
136                            for (ContextUpdate u : ((InvocationCall)args[4]).getUpdates())
137                                    valueMap.put(u.getComponentName() + (u.getExpression() != null ? "." + u.getExpression() : ""), u.getValue());
138                    }
139                    
140                    if (grails) {
141                                    // Special handling for Grails controllers
142                            try {
143                                    for (Class<?> cClass : componentClasses) {
144                                            if (cClass.isInterface())
145                                                    continue;
146                                            Method m = cClass.getDeclaredMethod("getProperty", String.class);
147                                            @SuppressWarnings("unchecked")
148                                            Map<String, Object> map = (Map<String, Object>)m.invoke(component, "params");
149                                            if (requestMap != null)
150                                                    map.putAll(requestMap);
151                                            if (valueMap != null)
152                                                    map.putAll(valueMap);
153                                    }
154                            }
155                            catch (Exception e) {
156                                    // Ignore, probably not a Grails controller
157                            }
158                    }
159                    ControllerRequestWrapper rw = new ControllerRequestWrapper(grails, context.getRequest(), componentName, (String)args[2], requestMap, valueMap);
160                    return new Object[] { "handle", new Object[] { rw, context.getResponse(), handler }};
161            }
162            
163            return super.beforeMethodSearch(instance, methodName, args);
164        }
165        
166        
167        @Override
168        public void prepareCall(ServiceInvocationContext context, IInvocationCall c, String componentName, Class<?> componentClass) {
169            super.prepareCall(context, c, componentName, componentClass);
170                    
171            if (componentName == null)
172                    return;
173            
174                    Object component = findComponent(componentName, componentClass);
175                    
176                    if (context.getBean() instanceof HandlerAdapter) {
177                            // In case of Spring controllers, call interceptors
178                    ApplicationContext webContext = getSpringContext();
179                    String[] interceptorNames = webContext.getBeanNamesForType(HandlerInterceptor.class);
180                    String[] webRequestInterceptors = webContext.getBeanNamesForType(WebRequestInterceptor.class);
181                    HandlerInterceptor[] interceptors = new HandlerInterceptor[interceptorNames.length+webRequestInterceptors.length];
182            
183                    int j = 0;
184                    for (int i = 0; i < webRequestInterceptors.length; i++)
185                        interceptors[j++] = new WebRequestHandlerInterceptorAdapter((WebRequestInterceptor)webContext.getBean(webRequestInterceptors[i]));
186                    for (int i = 0; i < interceptorNames.length; i++)
187                        interceptors[j++] = (HandlerInterceptor)webContext.getBean(interceptorNames[i]);
188                    
189                            HttpGraniteContext graniteContext = (HttpGraniteContext)GraniteContext.getCurrentInstance();
190                            
191                            graniteContext.getRequestMap().put(HandlerInterceptor.class.getName(), interceptors);
192                            
193                            try {
194                            for (int i = 0; i < interceptors.length; i++) {
195                                HandlerInterceptor interceptor = interceptors[i];
196                                interceptor.preHandle((HttpServletRequest)context.getParameters()[0], graniteContext.getResponse(), component);
197                            }
198                            }
199                            catch (Exception e) {
200                                    throw new ServiceException(e.getMessage(), e);
201                            }
202                    }
203        }
204        
205    
206        @Override
207        @SuppressWarnings("unchecked")
208        public IInvocationResult postCall(ServiceInvocationContext context, Object result, String componentName, Class<?> componentClass) {
209            List<ContextUpdate> results = null;
210            
211            Object component = null;
212            if (componentName != null && context.getBean() instanceof HandlerAdapter) {
213                            component = findComponent(componentName, componentClass);
214                            
215                            HttpGraniteContext graniteContext = (HttpGraniteContext)GraniteContext.getCurrentInstance();
216                            
217                    Map<Object, Object> modelMap = null;
218                    if (result instanceof ModelAndView) {
219                            ModelAndView modelAndView = (ModelAndView)result;
220                            modelMap = modelAndView.getModel();
221                            result = modelAndView.getViewName();
222                            
223                            if (context.getBean() instanceof HandlerAdapter) {
224                                    try {
225                                            HandlerInterceptor[] interceptors = (HandlerInterceptor[])graniteContext.getRequestMap().get(HandlerInterceptor.class.getName());
226                                            
227                                            if (interceptors != null) {
228                                            for (int i = interceptors.length-1; i >= 0; i--) {
229                                                HandlerInterceptor interceptor = interceptors[i];
230                                                interceptor.postHandle((HttpServletRequest)context.getParameters()[0], graniteContext.getResponse(), component, modelAndView);
231                                            }
232            
233                                                triggerAfterCompletion(component, interceptors.length-1, interceptors, graniteContext.getRequest(), graniteContext.getResponse(), null);
234                                            }
235                                    }
236                                    catch (Exception e) {
237                                            throw new ServiceException(e.getMessage(), e);
238                                    }
239                            }
240                    }
241                    
242                    if (modelMap != null) {
243                            Boolean localBinding = (Boolean)graniteContext.getRequestMap().get(SPRINGMVC_BINDING_ATTR);
244                            
245                            results = new ArrayList<ContextUpdate>();
246                            for (Map.Entry<Object, Object> me : modelMap.entrySet()) {
247                                            if (me.getKey().toString().startsWith("org.springframework.validation.")
248                                                            || (me.getValue() != null && (
249                                                                            me.getValue().getClass().getName().startsWith("groovy.lang.ExpandoMetaClass")
250                                                                            || me.getValue().getClass().getName().indexOf("$_closure") > 0
251                                                                            || me.getValue() instanceof Class)))                                                                    
252                                                    continue;
253                                            String variableName = me.getKey().toString();
254                                            if (Boolean.TRUE.equals(localBinding))
255                                                    results.add(new ContextUpdate(componentName, variableName, me.getValue(), 3, false));
256                                            else
257                                                    results.add(new ContextUpdate(variableName, null, me.getValue(), 3, false));
258                            }
259                    }
260                    
261                            boolean grails = getSpringContext().getClass().getName().indexOf("Grails") > 0;
262                            if (grails) {
263                                    // Special handling for Grails controllers: get flash content
264                                    try {
265                                    Set<Class<?>> componentClasses = findComponentClasses(componentName, componentClass);
266                                    for (Class<?> cClass : componentClasses) {
267                                            if (cClass.isInterface())
268                                                    continue;
269                                                    Method m = cClass.getDeclaredMethod("getProperty", String.class);
270                                                    Map<String, Object> map = (Map<String, Object>)m.invoke(component, "flash");
271                                                    if (results == null)
272                                                            results = new ArrayList<ContextUpdate>();
273                                                    for (Map.Entry<String, Object> me : map.entrySet()) {
274                                                            Object value = me.getValue();
275                                                            if (value != null && value.getClass().getName().startsWith("org.codehaus.groovy.runtime.GString"))
276                                                                    value = value.toString();
277                                                            results.add(new ContextUpdate("flash", me.getKey(), value, 3, false));
278                                                    }
279                                    }
280                                    }
281                                    catch (Exception e) {
282                                            throw new ServiceException("Flash scope retrieval failed", e);
283                                    }
284                            }
285            }
286                    
287            DataContext dataContext = DataContext.get();
288                    Set<Object[]> dataUpdates = dataContext != null ? dataContext.getDataUpdates() : null;
289                    Object[][] updates = null;
290                    if (dataUpdates != null && !dataUpdates.isEmpty())
291                            updates = dataUpdates.toArray(new Object[dataUpdates.size()][]);
292                    
293            InvocationResult ires = new InvocationResult(result, results);
294            if (component == null)
295                    component = context.getBean();
296            if (component.getClass().isAnnotationPresent(BypassTideMerge.class))
297                    ires.setMerge(false);
298            else if (!(context.getParameters().length > 0 && context.getParameters()[0] instanceof ControllerRequestWrapper)) {
299                    try {
300                            Method m = component.getClass().getMethod(context.getMethod().getName(), context.getMethod().getParameterTypes());
301                            if (m.isAnnotationPresent(BypassTideMerge.class))
302                                    ires.setMerge(false);
303                    }
304                    catch (Exception e) {
305                            log.warn("Could not find bean method", e);
306                    }
307            }
308            
309            ires.setUpdates(updates);
310            
311            return ires;
312        }
313    
314        @Override
315        public void postCallFault(ServiceInvocationContext context, Throwable t, String componentName, Class<?> componentClass) {
316            if (componentName != null && context.getBean() instanceof HandlerAdapter) {
317                            HttpGraniteContext graniteContext = (HttpGraniteContext)GraniteContext.getCurrentInstance();
318                            
319                            Object component = findComponent(componentName, componentClass);
320                            
321                            HandlerInterceptor[] interceptors = (HandlerInterceptor[])graniteContext.getRequestMap().get(HandlerInterceptor.class.getName());
322            
323                    triggerAfterCompletion(component, interceptors.length-1, interceptors, 
324                                    graniteContext.getRequest(), graniteContext.getResponse(), 
325                                    t instanceof Exception ? (Exception)t : null);
326            }
327                    
328            super.postCallFault(context, t, componentName, componentClass);
329        }
330    
331        
332        private void triggerAfterCompletion(Object component, int interceptorIndex, HandlerInterceptor[] interceptors, HttpServletRequest request, HttpServletResponse response, Exception ex) {
333                    for (int i = interceptorIndex; i >= 0; i--) {
334                            HandlerInterceptor interceptor = interceptors[i];
335                            try {
336                                    interceptor.afterCompletion(request, response, component, ex);
337                            }
338                            catch (Throwable ex2) {
339                                    log.error("HandlerInterceptor.afterCompletion threw exception", ex2);
340                            }
341                    }
342        }
343    
344        
345        
346        private class ControllerRequestWrapper extends HttpServletRequestWrapper {
347            private String componentName = null;
348            private String methodName = null;
349            private Map<String, Object> requestMap = null;
350            private Map<String, Object> valueMap = null;
351            
352                    public ControllerRequestWrapper(boolean grails, HttpServletRequest request, String componentName, String methodName, Map<String, Object> requestMap, Map<String, Object> valueMap) {
353                            super(request);
354                            this.componentName = componentName.substring(0, componentName.length()-"Controller".length());
355                            if (this.componentName.indexOf(".") > 0)
356                                    this.componentName = this.componentName.substring(this.componentName.lastIndexOf(".")+1);
357                            if (grails)
358                                    this.componentName = this.componentName.substring(0, 1).toLowerCase() + this.componentName.substring(1);
359                            this.methodName = methodName;
360                            this.requestMap = requestMap;
361                            this.valueMap = valueMap;
362                    }
363            
364                    @Override
365                    public String getRequestURI() {
366                            return getContextPath() + "/" + componentName + "/" + methodName;
367                    }
368                    
369                    @Override
370                    public String getServletPath() {
371                            return "/" + componentName + "/" + methodName;
372                    }
373                    
374                    public Object getRequestValue(String key) {
375                            return requestMap != null ? requestMap.get(key) : null;
376                    }
377                    
378                    public Object getBindValue(String key) {
379                            return valueMap != null ? valueMap.get(key) : null;
380                    }
381                    
382                    @Override
383                    public String getParameter(String name) {
384                            return requestMap != null && requestMap.containsKey(name) ? REQUEST_VALUE : null;
385                    }
386                    
387                    @Override
388                    public String[] getParameterValues(String name) {
389                            return requestMap != null && requestMap.containsKey(name) ? new String[] { REQUEST_VALUE } : null;
390                    }
391    
392                    @Override
393                    @SuppressWarnings({ "unchecked", "rawtypes" })
394                    public Map getParameterMap() {
395                            Map<String, Object> pmap = new HashMap<String, Object>();
396                            if (requestMap != null) {
397                                    for (String name : requestMap.keySet())
398                                            pmap.put(name, REQUEST_VALUE);
399                            }
400                            return pmap; 
401                    }
402    
403                    @Override
404                    @SuppressWarnings({ "unchecked", "rawtypes" })
405                    public Enumeration getParameterNames() {
406                            Hashtable ht = new Hashtable();
407                            if (requestMap != null)
408                                    ht.putAll(requestMap);
409                            return ht.keys();
410                    }
411        }
412        
413        
414        private class ControllerRequestDataBinder extends ServletRequestDataBinder {
415            
416            private ControllerRequestWrapper wrapper = null;
417            private Object target = null;
418    
419                    public ControllerRequestDataBinder(ServletRequest request, Object target, String objectName) {
420                            super(target, objectName);
421                            this.wrapper = (ControllerRequestWrapper)request;
422                            this.target = target;
423                    }
424                    
425                    private Object getBindValue(boolean request, Class<?> requiredType) {
426                            GraniteContext context = GraniteContext.getCurrentInstance();
427                            TidePersistenceManager pm = SpringMVCServiceContext.this.getTidePersistenceManager(true);
428                            ClassGetter classGetter = context.getGraniteConfig().getClassGetter();
429                            Object value = request ? wrapper.getRequestValue(getObjectName()) : wrapper.getBindValue(getObjectName());
430                            if (requiredType != null) {
431                                    Converter converter = context.getGraniteConfig().getConverters().getConverter(value, requiredType);
432                                    if (converter != null)
433                                            value = converter.convert(value, requiredType);
434                            }
435                            if (value != null && !request)
436                                    return SpringMVCServiceContext.this.mergeExternal(pm, classGetter, value, null, null, null);
437                            return value;
438                    }
439    
440                    @Override
441                    public void bind(ServletRequest request) {
442                            Object value = getBindValue(false, null);
443                            if (value != null)
444                                    target = value;
445                    }
446                    
447                    @Override
448                    public Object getTarget() {
449                            return target;                  
450                    }
451                    
452                    @SuppressWarnings("rawtypes")
453                    @Override
454                    public Object convertIfNecessary(Object value, Class requiredType, MethodParameter methodParam) throws TypeMismatchException {
455                            if (target == null && value == REQUEST_VALUE || (value instanceof String[] && ((String[])value)[0] == REQUEST_VALUE))
456                                    return getBindValue(true, requiredType);
457                            return super.convertIfNecessary(value, requiredType, methodParam);
458                    }
459        }
460    }