001    /**
002     *   GRANITE DATA SERVICES
003     *   Copyright (C) 2006-2013 GRANITE DATA SERVICES S.A.S.
004     *
005     *   This file is part of the Granite Data Services Platform.
006     *
007     *   Granite Data Services is free software; you can redistribute it and/or
008     *   modify it under the terms of the GNU Lesser General Public
009     *   License as published by the Free Software Foundation; either
010     *   version 2.1 of the License, or (at your option) any later version.
011     *
012     *   Granite Data Services is distributed in the hope that it will be useful,
013     *   but WITHOUT ANY WARRANTY; without even the implied warranty of
014     *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser
015     *   General Public License for more details.
016     *
017     *   You should have received a copy of the GNU Lesser General Public
018     *   License along with this library; if not, write to the Free Software
019     *   Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301,
020     *   USA, or see <http://www.gnu.org/licenses/>.
021     */
022    package org.granite.tide.spring;
023    
024    import java.lang.annotation.Annotation;
025    import java.lang.reflect.Constructor;
026    import java.lang.reflect.InvocationTargetException;
027    import java.lang.reflect.Method;
028    import java.util.HashSet;
029    import java.util.List;
030    import java.util.Map;
031    import java.util.Set;
032    
033    import javax.servlet.ServletContext;
034    
035    import org.granite.context.GraniteContext;
036    import org.granite.logging.Logger;
037    import org.granite.messaging.service.ServiceException;
038    import org.granite.messaging.service.ServiceInvocationContext;
039    import org.granite.messaging.webapp.HttpGraniteContext;
040    import org.granite.tide.IInvocationCall;
041    import org.granite.tide.IInvocationResult;
042    import org.granite.tide.TidePersistenceManager;
043    import org.granite.tide.TideServiceContext;
044    import org.granite.tide.TideTransactionManager;
045    import org.granite.tide.annotations.BypassTideMerge;
046    import org.granite.tide.async.AsyncPublisher;
047    import org.granite.tide.data.DataContext;
048    import org.granite.tide.data.DataUpdatePostprocessor;
049    import org.granite.tide.invocation.ContextUpdate;
050    import org.granite.tide.invocation.InvocationResult;
051    import org.granite.util.TypeUtil;
052    import org.springframework.aop.framework.Advised;
053    import org.springframework.aop.support.AopUtils;
054    import org.springframework.beans.BeansException;
055    import org.springframework.beans.factory.NoSuchBeanDefinitionException;
056    import org.springframework.context.ApplicationContext;
057    import org.springframework.context.ApplicationContextAware;
058    import org.springframework.orm.jpa.EntityManagerFactoryInfo;
059    import org.springframework.transaction.PlatformTransactionManager;
060    import org.springframework.web.context.support.WebApplicationContextUtils;
061    
062    
063    /**
064     *  @author Sebastien Deleuze
065     *      @author William DRAI
066     */
067    public class SpringServiceContext extends TideServiceContext implements ApplicationContextAware {
068    
069        private static final long serialVersionUID = 1L;
070        
071        protected transient ApplicationContext springContext = null;
072        
073        private String persistenceManagerBeanName = null;
074        private String entityManagerFactoryBeanName = null;
075        
076        private static final Logger log = Logger.getLogger(SpringServiceContext.class);
077                    
078        public SpringServiceContext() throws ServiceException {
079            super();
080            
081            log.debug("Getting spring context from container");
082            getSpringContext();
083        }
084        
085        public SpringServiceContext(ApplicationContext springContext) throws ServiceException {
086            super();
087            
088            this.springContext = springContext;
089        }
090        
091        public void setApplicationContext(ApplicationContext springContext) {
092            this.springContext = springContext;
093        }
094    
095        protected ApplicationContext getSpringContext() {
096            if (springContext == null) {
097                GraniteContext context = GraniteContext.getCurrentInstance();
098                ServletContext sc = ((HttpGraniteContext)context).getServletContext();
099                springContext = WebApplicationContextUtils.getRequiredWebApplicationContext(sc);
100            }
101            return springContext;           
102        }
103        
104        
105        @Override
106        protected AsyncPublisher getAsyncPublisher() {
107            return null;
108        }    
109        
110        @Override
111        public Object findComponent(String componentName, Class<?> componentClass) {
112            Object bean = null;
113            String key = COMPONENT_ATTR + (componentName != null ? componentName : "_CLASS_" + componentClass.getName());
114            
115            GraniteContext context = GraniteContext.getCurrentInstance();
116            if (context != null) {
117                    bean = context.getRequestMap().get(key);
118                    if (bean != null)
119                            return bean;
120            }
121            
122            ApplicationContext springContext = getSpringContext();
123            try {
124                    if (componentClass != null) {
125                            Map<String, ?> beans = springContext.getBeansOfType(componentClass);
126                            if (beans.size() == 1)
127                                    bean = beans.values().iterator().next();
128                            else if (beans.size() > 1 && componentName != null && !("".equals(componentName))) {
129                                    if (beans.containsKey(componentName))
130                                            bean = beans.get(componentName);
131                            }
132                            else if (beans.isEmpty() && springContext.getClass().getName().indexOf("Grails") > 0 && componentClass.getName().endsWith("Service")) {
133                                    try {
134                                            Object serviceClass = springContext.getBean(componentClass.getName() + "ServiceClass");                         
135                                            Method m = serviceClass.getClass().getMethod("getPropertyName");
136                                            String compName = (String)m.invoke(serviceClass);
137                                            bean = springContext.getBean(compName);
138                                    }
139                                    catch (NoSuchMethodException e) {
140                                            log.error(e, "Could not get service class for %s", componentClass.getName());
141                                    }
142                                    catch (InvocationTargetException e) {
143                                            log.error(e.getCause(), "Could not get service class for %s", componentClass.getName());
144                                    }
145                                    catch (IllegalAccessException e) {
146                                            log.error(e.getCause(), "Could not get service class for %s", componentClass.getName());
147                                    }
148                            }
149                    }
150                    if (bean == null && componentName != null && !("".equals(componentName)))
151                            bean = springContext.getBean(componentName);
152                    
153                if (context != null)
154                    context.getRequestMap().put(key, bean);
155                return bean;
156            }
157            catch (NoSuchBeanDefinitionException nexc) {
158                    if (componentName != null && componentName.endsWith("Controller")) {
159                            try {
160                                    int idx = componentName.lastIndexOf(".");
161                                    String controllerName = idx > 0 
162                                            ? componentName.substring(0, idx+1) + componentName.substring(idx+1, idx+2).toUpperCase() + componentName.substring(idx+2)
163                                            : componentName.substring(0, 1).toUpperCase() + componentName.substring(1);
164                                    bean = getSpringContext().getBean(controllerName);
165                        if (context != null)
166                            context.getRequestMap().put(key, bean);
167                                    return bean;
168                            }
169                    catch (NoSuchBeanDefinitionException nexc2) {
170                    }
171                    }
172                    
173                String msg = "Spring service named '" + componentName + "' does not exist.";
174                ServiceException e = new ServiceException(msg, nexc);
175                throw e;
176            } 
177            catch (BeansException bexc) {
178                String msg = "Unable to create Spring service named '" + componentName + "'";
179                ServiceException e = new ServiceException(msg, bexc);
180                throw e;
181            }    
182        }
183        
184        @Override
185        @SuppressWarnings("unchecked")
186        public Set<Class<?>> findComponentClasses(String componentName, Class<?> componentClass) {
187            String key = COMPONENT_CLASS_ATTR + componentName;
188            Set<Class<?>> classes = null; 
189            GraniteContext context = GraniteContext.getCurrentInstance();
190            if (context != null) {
191                    classes = (Set<Class<?>>)context.getRequestMap().get(key);
192                    if (classes != null)
193                            return classes;
194            }
195            
196            Object bean = findComponent(componentName, componentClass);
197            classes = buildComponentClasses(bean);        
198            if (classes == null)
199                    return null;
200            
201            if (context != null)
202                    context.getRequestMap().put(key, classes);
203            return classes;
204        }
205        
206        protected Set<Class<?>> buildComponentClasses(Object bean) {
207            Set<Class<?>> classes = new HashSet<Class<?>>();
208            for (Class<?> i : bean.getClass().getInterfaces())
209                    classes.add(i);
210            
211            try {
212                    while (bean instanceof Advised)
213                            bean = ((Advised)bean).getTargetSource().getTarget();
214                    
215                    classes.add(AopUtils.getTargetClass(bean));
216            }
217            catch (Exception e) {
218                log.warn(e, "Could not get AOP class for component " + bean.getClass());
219                    return null;
220            }
221            
222            return classes;
223        }
224        
225        protected boolean isBeanAnnotationPresent(Object bean, Class<? extends Annotation> annotationClass) {
226            if (bean.getClass().isAnnotationPresent(annotationClass))
227                    return true;
228            
229            try {
230                    while (bean instanceof Advised)
231                            bean = ((Advised)bean).getTargetSource().getTarget();
232                    
233                    if (AopUtils.getTargetClass(bean).isAnnotationPresent(annotationClass))
234                            return true;
235            }
236            catch (Exception e) {
237                log.warn(e, "Could not get AOP class for component " + bean.getClass());
238            }
239            
240            return false;
241        }
242        
243        protected boolean isBeanMethodAnnotationPresent(Object bean, String methodName, Class<?>[] methodArgTypes, Class<? extends Annotation> annotationClass) {
244            try {
245                    Method m = bean.getClass().getMethod(methodName, methodArgTypes);
246                    if (m.isAnnotationPresent(annotationClass))
247                            return true;
248                    
249                    while (bean instanceof Advised)
250                            bean = ((Advised)bean).getTargetSource().getTarget();
251                    
252                    m = AopUtils.getTargetClass(bean).getMethod(methodName, methodArgTypes);
253                    if (m.isAnnotationPresent(annotationClass))
254                            return true;
255                    }
256            catch (Exception e) {
257                    log.warn("Could not find bean method", e);
258            }
259            
260            return false;
261        }
262    
263        
264        @Override
265        public void prepareCall(ServiceInvocationContext context, IInvocationCall c, String componentName, Class<?> componentClass) {
266            DataContext.init();
267            
268            DataUpdatePostprocessor dupp = (DataUpdatePostprocessor)findComponent(null, DataUpdatePostprocessor.class);
269            if (dupp != null)
270                    DataContext.get().setDataUpdatePostprocessor(dupp);
271        }
272    
273        @Override
274        public IInvocationResult postCall(ServiceInvocationContext context, Object result, String componentName, Class<?> componentClass) {
275                    List<ContextUpdate> results = null;
276            DataContext dataContext = DataContext.get();
277                    Object[][] updates = dataContext != null ? dataContext.getUpdates() : null;
278                    
279            InvocationResult ires = new InvocationResult(result, results);
280            if (isBeanAnnotationPresent(context.getBean(), BypassTideMerge.class))
281                    ires.setMerge(false);
282            else if (isBeanMethodAnnotationPresent(context.getBean(), context.getMethod().getName(), context.getMethod().getParameterTypes(), BypassTideMerge.class))
283                            ires.setMerge(false);
284            
285            ires.setUpdates(updates);
286            
287            return ires;
288        }
289    
290        @Override
291        public void postCallFault(ServiceInvocationContext context, Throwable t, String componentName, Class<?> componentClass) {        
292        }
293        
294        
295        public void setEntityManagerFactoryBeanName(String beanName) {
296            this.entityManagerFactoryBeanName = beanName;
297        }
298        
299        public void setPersistenceManagerBeanName(String beanName) {
300            this.persistenceManagerBeanName = beanName;
301        }
302        
303        /**
304         *  Create a TidePersistenceManager
305         *  
306         *  @param create create if not existent (can be false for use in entity merge)
307         *  @return a PersistenceContextManager
308         */
309        @Override
310        protected TidePersistenceManager getTidePersistenceManager(boolean create) {
311            if (!create)
312                return null;
313            
314            TidePersistenceManager pm = (TidePersistenceManager)GraniteContext.getCurrentInstance().getRequestMap().get(TidePersistenceManager.class.getName());
315            if (pm != null)
316                    return pm;
317            
318            pm = createPersistenceManager();
319            GraniteContext.getCurrentInstance().getRequestMap().put(TidePersistenceManager.class.getName(), pm);
320            return pm;
321        }
322        
323        private TidePersistenceManager createPersistenceManager() {
324            if (persistenceManagerBeanName == null) {
325                    if (entityManagerFactoryBeanName == null) {
326                            // No bean or entity manager factory specified 
327                            
328                            // 1. Look for a TidePersistenceManager bean
329                            Map<String, ?> pms = springContext.getBeansOfType(TidePersistenceManager.class);
330                            if (pms.size() > 1)
331                                    throw new RuntimeException("More than one Tide persistence managers defined");
332                            
333                            if (pms.size() == 1)
334                                    return (TidePersistenceManager)pms.values().iterator().next();
335                            
336                            // 2. If not found, try to determine the Spring transaction manager                     
337                            Map<String, ?> tms = springContext.getBeansOfType(PlatformTransactionManager.class);
338                            if (tms.isEmpty())
339                                    log.debug("No Spring transaction manager found, specify a persistence-manager-bean-name or entity-manager-factory-bean-name");
340                            else if (tms.size() > 1)
341                                    log.debug("More than one Spring transaction manager found, specify a persistence-manager-bean-name or entity-manager-factory-bean-name");
342                            else if (tms.size() == 1) {
343                                    PlatformTransactionManager ptm = (PlatformTransactionManager)tms.values().iterator().next();
344                                            
345                                    // If no entity manager, we define a Spring persistence manager 
346                                            // that will try to infer persistence info from the Spring transaction manager
347                                            return new SpringPersistenceManager(ptm);
348                            }
349                    }
350                    
351                String emfBeanName = entityManagerFactoryBeanName != null ? entityManagerFactoryBeanName : "entityManagerFactory";
352                try {
353                    // Lookup the specified entity manager factory
354                    Object emf = findComponent(emfBeanName, null);
355                    
356                    // Try to determine the Spring transaction manager
357                    TideTransactionManager tm = null;
358                            Map<String, ?> ptms = springContext.getBeansOfType(PlatformTransactionManager.class);
359                            if (ptms.size() == 1) {
360                                    log.debug("Found Spring transaction manager " + ptms.keySet().iterator().next());
361                                    tm = new SpringTransactionManager((PlatformTransactionManager)ptms.values().iterator().next());
362                            }
363                    
364                                    Class<?> emfClass = TypeUtil.forName("javax.persistence.EntityManagerFactory");
365                        Class<?> pcmClass = TypeUtil.forName("org.granite.tide.data.JPAPersistenceManager");
366                        Constructor<?>[] cs = pcmClass.getConstructors();
367                        if (tm != null) {
368                            for (Constructor<?> c : cs) {
369                                    if (c.getParameterTypes().length == 2 && emfClass.isAssignableFrom(c.getParameterTypes()[0])
370                                            && TideTransactionManager.class.isAssignableFrom(c.getParameterTypes()[1])) {
371                                            log.debug("Created JPA persistence manager with Spring transaction manager");
372                                            return (TidePersistenceManager)c.newInstance(((EntityManagerFactoryInfo)emf).getNativeEntityManagerFactory(), tm);
373                                    }
374                            }
375                        }
376                        else {
377                                for (Constructor<?> c : cs) {
378                                    if (c.getParameterTypes().length == 1 && emfClass.isAssignableFrom(c.getParameterTypes()[0])) {
379                                            log.debug("Created default JPA persistence manager");
380                                            return (TidePersistenceManager)c.newInstance(emf);
381                                    }
382                                }
383                        }
384                        
385                        throw new RuntimeException("Default Tide persistence manager not found");
386                }
387                catch (ServiceException e) {
388                    if (entityManagerFactoryBeanName != null)
389                            log.debug("EntityManagerFactory named %s not found, JPA support disabled", emfBeanName);
390                    
391                    return null;
392                }
393                catch (Exception e) {
394                    throw new RuntimeException("Could not create default Tide persistence manager", e);
395                }
396            }
397            
398            return (TidePersistenceManager)findComponent(persistenceManagerBeanName, null);
399        }
400    }