001/**
002 *   GRANITE DATA SERVICES
003 *   Copyright (C) 2006-2013 GRANITE DATA SERVICES S.A.S.
004 *
005 *   This file is part of the Granite Data Services Platform.
006 *
007 *   Granite Data Services is free software; you can redistribute it and/or
008 *   modify it under the terms of the GNU Lesser General Public
009 *   License as published by the Free Software Foundation; either
010 *   version 2.1 of the License, or (at your option) any later version.
011 *
012 *   Granite Data Services is distributed in the hope that it will be useful,
013 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
014 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser
015 *   General Public License for more details.
016 *
017 *   You should have received a copy of the GNU Lesser General Public
018 *   License along with this library; if not, write to the Free Software
019 *   Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301,
020 *   USA, or see <http://www.gnu.org/licenses/>.
021 */
022package org.granite.tide.spring;
023
024import java.lang.reflect.Method;
025import java.util.ArrayList;
026import java.util.Enumeration;
027import java.util.HashMap;
028import java.util.Hashtable;
029import java.util.List;
030import java.util.Map;
031import java.util.Set;
032
033import javax.servlet.ServletRequest;
034import javax.servlet.http.HttpServletRequest;
035import javax.servlet.http.HttpServletRequestWrapper;
036import javax.servlet.http.HttpServletResponse;
037
038import org.granite.context.GraniteContext;
039import org.granite.logging.Logger;
040import org.granite.messaging.amf.io.convert.Converter;
041import org.granite.messaging.amf.io.util.ClassGetter;
042import org.granite.messaging.service.ServiceException;
043import org.granite.messaging.service.ServiceInvocationContext;
044import org.granite.messaging.webapp.HttpGraniteContext;
045import org.granite.tide.IInvocationCall;
046import org.granite.tide.IInvocationResult;
047import org.granite.tide.annotations.BypassTideMerge;
048import org.granite.tide.data.DataContext;
049import org.granite.tide.invocation.ContextUpdate;
050import org.granite.tide.invocation.InvocationCall;
051import org.granite.tide.invocation.InvocationResult;
052import org.granite.util.TypeUtil;
053import org.springframework.beans.TypeMismatchException;
054import org.springframework.context.ApplicationContext;
055import org.springframework.core.MethodParameter;
056import org.springframework.web.bind.ServletRequestDataBinder;
057import org.springframework.web.context.request.WebRequestInterceptor;
058import org.springframework.web.servlet.HandlerAdapter;
059import org.springframework.web.servlet.HandlerInterceptor;
060import org.springframework.web.servlet.ModelAndView;
061import org.springframework.web.servlet.handler.WebRequestHandlerInterceptorAdapter;
062import org.springframework.web.servlet.mvc.Controller;
063import org.springframework.web.servlet.mvc.SimpleControllerHandlerAdapter;
064import org.springframework.web.servlet.mvc.annotation.AnnotationMethodHandlerAdapter;
065
066
067/**
068 * @author William DRAI
069 */
070public class SpringMVCServiceContext extends SpringServiceContext {
071
072    private static final long serialVersionUID = 1L;
073    
074    private static final String REQUEST_VALUE = "__REQUEST_VALUE__";
075    
076    private static final Logger log = Logger.getLogger(SpringMVCServiceContext.class);
077
078    
079    public SpringMVCServiceContext() throws ServiceException {
080        super();
081    }
082    
083    public SpringMVCServiceContext(ApplicationContext springContext) throws ServiceException {
084        super(springContext);
085    }
086    
087    
088    @Override
089    public Object adjustInvokee(Object instance, String componentName, Set<Class<?>> componentClasses) {
090        for (Class<?> componentClass : componentClasses) {
091                if (componentClass.isAnnotationPresent(org.springframework.stereotype.Controller.class)) {
092                        return new AnnotationMethodHandlerAdapter() {
093                                @Override
094                                protected ServletRequestDataBinder createBinder(HttpServletRequest request, Object target, String objectName) throws Exception {
095                                        return new ControllerRequestDataBinder(request, target, objectName);
096                                }
097                        };
098                }
099        }
100        if (Controller.class.isInstance(instance) || (componentName != null && componentName.endsWith("Controller")))
101                return new SimpleControllerHandlerAdapter();
102        
103        return instance;
104    }
105    
106    
107    private static final String SPRINGMVC_BINDING_ATTR = "__SPRINGMVC_LOCAL_BINDING__";
108    
109    @Override
110    public Object[] beforeMethodSearch(Object instance, String methodName, Object[] args) {
111        if (instance instanceof HandlerAdapter) {
112                boolean grails = getSpringContext().getClass().getName().indexOf("Grails") > 0;
113                
114                String componentName = (String)args[0];
115                String componentClassName = (String)args[1];
116            Class<?> componentClass = null;
117            try {
118                if (componentClassName != null)
119                        componentClass = TypeUtil.forName(componentClassName);
120            }
121            catch (ClassNotFoundException e) {
122                throw new ServiceException("Component class not found " + componentClassName, e);
123            }
124                Object component = findComponent(componentName, componentClass);
125                Set<Class<?>> componentClasses = findComponentClasses(componentName, componentClass);
126                Object handler = component;
127                if (grails && componentName.endsWith("Controller")) {
128                        // Special handling for Grails controllers
129                        handler = springContext.getBean("mainSimpleController");
130                }
131                HttpGraniteContext context = (HttpGraniteContext)GraniteContext.getCurrentInstance();
132                @SuppressWarnings("unchecked")
133                Map<String, Object> requestMap = (args[3] != null && args[3] instanceof Object[] && ((Object[])args[3]).length >= 1 && ((Object[])args[3]).length <= 2 && ((Object[])args[3])[0] instanceof Map) 
134                        ? (Map<String, Object>)((Object[])args[3])[0] 
135                        : null;
136                boolean localBinding = false;
137                if (args[3] != null && args[3] instanceof Object[] && ((Object[])args[3]).length == 2 
138                                && ((Object[])args[3])[0] instanceof Map<?, ?> && ((Object[])args[3])[1] instanceof Boolean)
139                        localBinding = (Boolean)((Object[])args[3])[1];
140                context.getRequestMap().put(SPRINGMVC_BINDING_ATTR, localBinding);
141                
142                Map<String, Object> valueMap = null;
143                if (args[4] instanceof InvocationCall) {
144                        valueMap = new HashMap<String, Object>();
145                        for (ContextUpdate u : ((InvocationCall)args[4]).getUpdates())
146                                valueMap.put(u.getComponentName() + (u.getExpression() != null ? "." + u.getExpression() : ""), u.getValue());
147                }
148                
149                if (grails) {
150                                // Special handling for Grails controllers
151                        try {
152                                for (Class<?> cClass : componentClasses) {
153                                        if (cClass.isInterface())
154                                                continue;
155                                        Method m = cClass.getDeclaredMethod("getProperty", String.class);
156                                        @SuppressWarnings("unchecked")
157                                        Map<String, Object> map = (Map<String, Object>)m.invoke(component, "params");
158                                        if (requestMap != null)
159                                                map.putAll(requestMap);
160                                        if (valueMap != null)
161                                                map.putAll(valueMap);
162                                }
163                        }
164                        catch (Exception e) {
165                                // Ignore, probably not a Grails controller
166                        }
167                }
168                ControllerRequestWrapper rw = new ControllerRequestWrapper(grails, context.getRequest(), componentName, (String)args[2], requestMap, valueMap);
169                return new Object[] { "handle", new Object[] { rw, context.getResponse(), handler }};
170        }
171        
172        return super.beforeMethodSearch(instance, methodName, args);
173    }
174    
175    
176    @Override
177    public void prepareCall(ServiceInvocationContext context, IInvocationCall c, String componentName, Class<?> componentClass) {
178        super.prepareCall(context, c, componentName, componentClass);
179                
180        if (componentName == null)
181                return;
182        
183                Object component = findComponent(componentName, componentClass);
184                
185                if (context.getBean() instanceof HandlerAdapter) {
186                        // In case of Spring controllers, call interceptors
187                ApplicationContext webContext = getSpringContext();
188                String[] interceptorNames = webContext.getBeanNamesForType(HandlerInterceptor.class);
189                String[] webRequestInterceptors = webContext.getBeanNamesForType(WebRequestInterceptor.class);
190                HandlerInterceptor[] interceptors = new HandlerInterceptor[interceptorNames.length+webRequestInterceptors.length];
191        
192                int j = 0;
193                for (int i = 0; i < webRequestInterceptors.length; i++)
194                    interceptors[j++] = new WebRequestHandlerInterceptorAdapter((WebRequestInterceptor)webContext.getBean(webRequestInterceptors[i]));
195                for (int i = 0; i < interceptorNames.length; i++)
196                    interceptors[j++] = (HandlerInterceptor)webContext.getBean(interceptorNames[i]);
197                
198                        HttpGraniteContext graniteContext = (HttpGraniteContext)GraniteContext.getCurrentInstance();
199                        
200                        graniteContext.getRequestMap().put(HandlerInterceptor.class.getName(), interceptors);
201                        
202                        try {
203                        for (int i = 0; i < interceptors.length; i++) {
204                            HandlerInterceptor interceptor = interceptors[i];
205                            interceptor.preHandle((HttpServletRequest)context.getParameters()[0], graniteContext.getResponse(), component);
206                        }
207                        }
208                        catch (Exception e) {
209                                throw new ServiceException(e.getMessage(), e);
210                        }
211                }
212    }
213    
214
215    @Override
216    @SuppressWarnings("unchecked")
217    public IInvocationResult postCall(ServiceInvocationContext context, Object result, String componentName, Class<?> componentClass) {
218        List<ContextUpdate> results = null;
219        
220        Object component = null;
221        if (componentName != null && context.getBean() instanceof HandlerAdapter) {
222                        component = findComponent(componentName, componentClass);
223                        
224                        HttpGraniteContext graniteContext = (HttpGraniteContext)GraniteContext.getCurrentInstance();
225                        
226                Map<String, Object> modelMap = null;
227                if (result instanceof ModelAndView) {
228                        ModelAndView modelAndView = (ModelAndView)result;
229                        modelMap = modelAndView.getModel();
230                        result = modelAndView.getViewName();
231                        
232                        if (context.getBean() instanceof HandlerAdapter) {
233                                try {
234                                        HandlerInterceptor[] interceptors = (HandlerInterceptor[])graniteContext.getRequestMap().get(HandlerInterceptor.class.getName());
235                                        
236                                        if (interceptors != null) {
237                                        for (int i = interceptors.length-1; i >= 0; i--) {
238                                            HandlerInterceptor interceptor = interceptors[i];
239                                            interceptor.postHandle((HttpServletRequest)context.getParameters()[0], graniteContext.getResponse(), component, modelAndView);
240                                        }
241        
242                                            triggerAfterCompletion(component, interceptors.length-1, interceptors, graniteContext.getRequest(), graniteContext.getResponse(), null);
243                                        }
244                                }
245                                catch (Exception e) {
246                                        throw new ServiceException(e.getMessage(), e);
247                                }
248                        }
249                }
250                
251                if (modelMap != null) {
252                        Boolean localBinding = (Boolean)graniteContext.getRequestMap().get(SPRINGMVC_BINDING_ATTR);
253                        
254                        results = new ArrayList<ContextUpdate>();
255                        for (Map.Entry<String, Object> me : modelMap.entrySet()) {
256                                        if (me.getKey().toString().startsWith("org.springframework.validation.")
257                                                        || (me.getValue() != null && (
258                                                                        me.getValue().getClass().getName().startsWith("groovy.lang.ExpandoMetaClass")
259                                                                        || me.getValue().getClass().getName().indexOf("$_closure") > 0
260                                                                        || me.getValue() instanceof Class)))                                                                    
261                                                continue;
262                                        String variableName = me.getKey().toString();
263                                        if (Boolean.TRUE.equals(localBinding))
264                                                results.add(new ContextUpdate(componentName, variableName, me.getValue(), 3, false));
265                                        else
266                                                results.add(new ContextUpdate(variableName, null, me.getValue(), 3, false));
267                        }
268                }
269                
270                        boolean grails = getSpringContext().getClass().getName().indexOf("Grails") > 0;
271                        if (grails) {
272                                // Special handling for Grails controllers: get flash content
273                                try {
274                                Set<Class<?>> componentClasses = findComponentClasses(componentName, componentClass);
275                                for (Class<?> cClass : componentClasses) {
276                                        if (cClass.isInterface())
277                                                continue;
278                                                Method m = cClass.getDeclaredMethod("getProperty", String.class);
279                                                Map<String, Object> map = (Map<String, Object>)m.invoke(component, "flash");
280                                                if (results == null)
281                                                        results = new ArrayList<ContextUpdate>();
282                                                for (Map.Entry<String, Object> me : map.entrySet()) {
283                                                        Object value = me.getValue();
284                                                        if (value != null && value.getClass().getName().startsWith("org.codehaus.groovy.runtime.GString"))
285                                                                value = value.toString();
286                                                        results.add(new ContextUpdate("flash", me.getKey(), value, 3, false));
287                                                }
288                                }
289                                }
290                                catch (Exception e) {
291                                        throw new ServiceException("Flash scope retrieval failed", e);
292                                }
293                        }
294        }
295                
296        DataContext dataContext = DataContext.get();
297                Object[][] updates = dataContext != null ? dataContext.getUpdates() : null;
298                
299        InvocationResult ires = new InvocationResult(result, results);
300        if (component == null)
301                component = context.getBean();
302        if (isBeanAnnotationPresent(component, BypassTideMerge.class))
303                ires.setMerge(false);
304        else if (!(context.getParameters().length > 0 && context.getParameters()[0] instanceof ControllerRequestWrapper)) {
305                if (isBeanMethodAnnotationPresent(component, context.getMethod().getName(), context.getMethod().getParameterTypes(), BypassTideMerge.class))
306                        ires.setMerge(false);
307        }
308        
309        ires.setUpdates(updates);
310        
311        return ires;
312    }
313
314    @Override
315    public void postCallFault(ServiceInvocationContext context, Throwable t, String componentName, Class<?> componentClass) {
316        if (componentName != null && context.getBean() instanceof HandlerAdapter) {
317                        HttpGraniteContext graniteContext = (HttpGraniteContext)GraniteContext.getCurrentInstance();
318                        
319                        Object component = findComponent(componentName, componentClass);
320                        
321                        HandlerInterceptor[] interceptors = (HandlerInterceptor[])graniteContext.getRequestMap().get(HandlerInterceptor.class.getName());
322        
323                triggerAfterCompletion(component, interceptors.length-1, interceptors, 
324                                graniteContext.getRequest(), graniteContext.getResponse(), 
325                                t instanceof Exception ? (Exception)t : null);
326        }
327                
328        super.postCallFault(context, t, componentName, componentClass);
329    }
330
331    
332    private void triggerAfterCompletion(Object component, int interceptorIndex, HandlerInterceptor[] interceptors, HttpServletRequest request, HttpServletResponse response, Exception ex) {
333                for (int i = interceptorIndex; i >= 0; i--) {
334                        HandlerInterceptor interceptor = interceptors[i];
335                        try {
336                                interceptor.afterCompletion(request, response, component, ex);
337                        }
338                        catch (Throwable ex2) {
339                                log.error("HandlerInterceptor.afterCompletion threw exception", ex2);
340                        }
341                }
342    }
343
344    
345    
346    private class ControllerRequestWrapper extends HttpServletRequestWrapper {
347        private String componentName = null;
348        private String methodName = null;
349        private Map<String, Object> requestMap = null;
350        private Map<String, Object> valueMap = null;
351        
352                public ControllerRequestWrapper(boolean grails, HttpServletRequest request, String componentName, String methodName, Map<String, Object> requestMap, Map<String, Object> valueMap) {
353                        super(request);
354                        this.componentName = componentName.substring(0, componentName.length()-"Controller".length());
355                        if (this.componentName.indexOf(".") > 0)
356                                this.componentName = this.componentName.substring(this.componentName.lastIndexOf(".")+1);
357                        if (grails)
358                                this.componentName = this.componentName.substring(0, 1).toLowerCase() + this.componentName.substring(1);
359                        this.methodName = methodName;
360                        this.requestMap = requestMap;
361                        this.valueMap = valueMap;
362                }
363        
364                @Override
365                public String getRequestURI() {
366                        return getContextPath() + "/" + componentName + "/" + methodName;
367                }
368                
369                @Override
370                public String getServletPath() {
371                        return "/" + componentName + "/" + methodName;
372                }
373                
374                public Object getRequestValue(String key) {
375                        return requestMap != null ? requestMap.get(key) : null;
376                }
377                
378                public Object getBindValue(String key) {
379                        return valueMap != null ? valueMap.get(key) : null;
380                }
381                
382                @Override
383                public String getParameter(String name) {
384                        return requestMap != null && requestMap.containsKey(name) ? REQUEST_VALUE : null;
385                }
386                
387                @Override
388                public String[] getParameterValues(String name) {
389                        return requestMap != null && requestMap.containsKey(name) ? new String[] { REQUEST_VALUE } : null;
390                }
391
392                @Override
393                @SuppressWarnings({ "unchecked", "rawtypes" })
394                public Map getParameterMap() {
395                        Map<String, Object> pmap = new HashMap<String, Object>();
396                        if (requestMap != null) {
397                                for (String name : requestMap.keySet())
398                                        pmap.put(name, REQUEST_VALUE);
399                        }
400                        return pmap; 
401                }
402
403                @Override
404                @SuppressWarnings({ "unchecked", "rawtypes" })
405                public Enumeration getParameterNames() {
406                        Hashtable ht = new Hashtable();
407                        if (requestMap != null)
408                                ht.putAll(requestMap);
409                        return ht.keys();
410                }
411    }
412    
413    
414    private class ControllerRequestDataBinder extends ServletRequestDataBinder {
415        
416        private ControllerRequestWrapper wrapper = null;
417        private Object target = null;
418
419                public ControllerRequestDataBinder(ServletRequest request, Object target, String objectName) {
420                        super(target, objectName);
421                        this.wrapper = (ControllerRequestWrapper)request;
422                        this.target = target;
423                }
424                
425                private Object getBindValue(boolean request, Class<?> requiredType) {
426                        GraniteContext context = GraniteContext.getCurrentInstance();
427                        ClassGetter classGetter = context.getGraniteConfig().getClassGetter();
428                        Object value = request ? wrapper.getRequestValue(getObjectName()) : wrapper.getBindValue(getObjectName());
429                        if (requiredType != null) {
430                                Converter converter = context.getGraniteConfig().getConverters().getConverter(value, requiredType);
431                                if (converter != null)
432                                        value = converter.convert(value, requiredType);
433                        }
434                        if (value != null && !request)
435                                return SpringMVCServiceContext.this.mergeExternal(classGetter, value, null, null, null);
436                        return value;
437                }
438
439                @Override
440                public void bind(ServletRequest request) {
441                        Object value = getBindValue(false, null);
442                        if (value != null)
443                                target = value;
444                }
445                
446                @Override
447                public Object getTarget() {
448                        return target;                  
449                }
450                
451                @SuppressWarnings({ "rawtypes", "unchecked" })
452                @Override
453                public Object convertIfNecessary(Object value, Class requiredType, MethodParameter methodParam) throws TypeMismatchException {
454                        if (target == null && value == REQUEST_VALUE || (value instanceof String[] && ((String[])value)[0] == REQUEST_VALUE))
455                                return getBindValue(true, requiredType);
456                        return super.convertIfNecessary(value, requiredType, methodParam);
457                }
458    }
459}