001    /*
002      GRANITE DATA SERVICES
003      Copyright (C) 2011 GRANITE DATA SERVICES S.A.S.
004    
005      This file is part of Granite Data Services.
006    
007      Granite Data Services is free software; you can redistribute it and/or modify
008      it under the terms of the GNU Library General Public License as published by
009      the Free Software Foundation; either version 2 of the License, or (at your
010      option) any later version.
011    
012      Granite Data Services is distributed in the hope that it will be useful, but
013      WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
014      FITNESS FOR A PARTICULAR PURPOSE. See the GNU Library General Public License
015      for more details.
016    
017      You should have received a copy of the GNU Library General Public License
018      along with this library; if not, see <http://www.gnu.org/licenses/>.
019    */
020    
021    package org.granite.tide.spring;
022    
023    import java.lang.annotation.Annotation;
024    import java.lang.reflect.Constructor;
025    import java.lang.reflect.InvocationTargetException;
026    import java.lang.reflect.Method;
027    import java.util.HashSet;
028    import java.util.List;
029    import java.util.Map;
030    import java.util.Set;
031    
032    import javax.servlet.ServletContext;
033    
034    import org.granite.context.GraniteContext;
035    import org.granite.logging.Logger;
036    import org.granite.messaging.service.ServiceException;
037    import org.granite.messaging.service.ServiceInvocationContext;
038    import org.granite.messaging.webapp.HttpGraniteContext;
039    import org.granite.tide.IInvocationCall;
040    import org.granite.tide.IInvocationResult;
041    import org.granite.tide.TidePersistenceManager;
042    import org.granite.tide.TideServiceContext;
043    import org.granite.tide.TideTransactionManager;
044    import org.granite.tide.annotations.BypassTideMerge;
045    import org.granite.tide.async.AsyncPublisher;
046    import org.granite.tide.data.DataContext;
047    import org.granite.tide.data.DataUpdatePostprocessor;
048    import org.granite.tide.invocation.ContextUpdate;
049    import org.granite.tide.invocation.InvocationResult;
050    import org.granite.util.TypeUtil;
051    import org.springframework.aop.framework.Advised;
052    import org.springframework.aop.support.AopUtils;
053    import org.springframework.beans.BeansException;
054    import org.springframework.beans.factory.NoSuchBeanDefinitionException;
055    import org.springframework.context.ApplicationContext;
056    import org.springframework.context.ApplicationContextAware;
057    import org.springframework.orm.jpa.EntityManagerFactoryInfo;
058    import org.springframework.transaction.PlatformTransactionManager;
059    import org.springframework.web.context.support.WebApplicationContextUtils;
060    
061    
062    /**
063     *  @author Sebastien Deleuze
064     *      @author William DRAI
065     */
066    public class SpringServiceContext extends TideServiceContext implements ApplicationContextAware {
067    
068        private static final long serialVersionUID = 1L;
069        
070        protected transient ApplicationContext springContext = null;
071        
072        private String persistenceManagerBeanName = null;
073        private String entityManagerFactoryBeanName = null;
074        
075        private static final Logger log = Logger.getLogger(SpringServiceContext.class);
076                    
077        public SpringServiceContext() throws ServiceException {
078            super();
079            
080            log.debug("Getting spring context from container");
081            getSpringContext();
082        }
083        
084        public SpringServiceContext(ApplicationContext springContext) throws ServiceException {
085            super();
086            
087            this.springContext = springContext;
088        }
089        
090        public void setApplicationContext(ApplicationContext springContext) {
091            this.springContext = springContext;
092        }
093    
094        protected ApplicationContext getSpringContext() {
095            if (springContext == null) {
096                GraniteContext context = GraniteContext.getCurrentInstance();
097                ServletContext sc = ((HttpGraniteContext)context).getServletContext();
098                springContext = WebApplicationContextUtils.getRequiredWebApplicationContext(sc);
099            }
100            return springContext;           
101        }
102        
103        
104        @Override
105        protected AsyncPublisher getAsyncPublisher() {
106            return null;
107        }    
108        
109        @Override
110        public Object findComponent(String componentName, Class<?> componentClass) {
111            Object bean = null;
112            String key = COMPONENT_ATTR + (componentName != null ? componentName : "_CLASS_" + componentClass.getName());
113            
114            GraniteContext context = GraniteContext.getCurrentInstance();
115            if (context != null) {
116                    bean = context.getRequestMap().get(key);
117                    if (bean != null)
118                            return bean;
119            }
120            
121            ApplicationContext springContext = getSpringContext();
122            try {
123                    if (componentClass != null) {
124                            Map<String, ?> beans = springContext.getBeansOfType(componentClass);
125                            if (beans.size() == 1)
126                                    bean = beans.values().iterator().next();
127                            else if (beans.size() > 1 && componentName != null && !("".equals(componentName))) {
128                                    if (beans.containsKey(componentName))
129                                            bean = beans.get(componentName);
130                            }
131                            else if (beans.isEmpty() && springContext.getClass().getName().indexOf("Grails") > 0 && componentClass.getName().endsWith("Service")) {
132                                    try {
133                                            Object serviceClass = springContext.getBean(componentClass.getName() + "ServiceClass");                         
134                                            Method m = serviceClass.getClass().getMethod("getPropertyName");
135                                            String compName = (String)m.invoke(serviceClass);
136                                            bean = springContext.getBean(compName);
137                                    }
138                                    catch (NoSuchMethodException e) {
139                                            log.error(e, "Could not get service class for %s", componentClass.getName());
140                                    }
141                                    catch (InvocationTargetException e) {
142                                            log.error(e.getCause(), "Could not get service class for %s", componentClass.getName());
143                                    }
144                                    catch (IllegalAccessException e) {
145                                            log.error(e.getCause(), "Could not get service class for %s", componentClass.getName());
146                                    }
147                            }
148                    }
149                    if (bean == null && componentName != null && !("".equals(componentName)))
150                            bean = springContext.getBean(componentName);
151                    
152                if (context != null)
153                    context.getRequestMap().put(key, bean);
154                return bean;
155            }
156            catch (NoSuchBeanDefinitionException nexc) {
157                    if (componentName != null && componentName.endsWith("Controller")) {
158                            try {
159                                    int idx = componentName.lastIndexOf(".");
160                                    String controllerName = idx > 0 
161                                            ? componentName.substring(0, idx+1) + componentName.substring(idx+1, idx+2).toUpperCase() + componentName.substring(idx+2)
162                                            : componentName.substring(0, 1).toUpperCase() + componentName.substring(1);
163                                    bean = getSpringContext().getBean(controllerName);
164                        if (context != null)
165                            context.getRequestMap().put(key, bean);
166                                    return bean;
167                            }
168                    catch (NoSuchBeanDefinitionException nexc2) {
169                    }
170                    }
171                    
172                String msg = "Spring service named '" + componentName + "' does not exist.";
173                ServiceException e = new ServiceException(msg, nexc);
174                throw e;
175            } 
176            catch (BeansException bexc) {
177                String msg = "Unable to create Spring service named '" + componentName + "'";
178                ServiceException e = new ServiceException(msg, bexc);
179                throw e;
180            }    
181        }
182        
183        @Override
184        @SuppressWarnings("unchecked")
185        public Set<Class<?>> findComponentClasses(String componentName, Class<?> componentClass) {
186            String key = COMPONENT_CLASS_ATTR + componentName;
187            Set<Class<?>> classes = null; 
188            GraniteContext context = GraniteContext.getCurrentInstance();
189            if (context != null) {
190                    classes = (Set<Class<?>>)context.getRequestMap().get(key);
191                    if (classes != null)
192                            return classes;
193            }
194            
195            Object bean = findComponent(componentName, componentClass);
196            classes = buildComponentClasses(bean);        
197            if (classes == null)
198                    return null;
199            
200            if (context != null)
201                    context.getRequestMap().put(key, classes);
202            return classes;
203        }
204        
205        protected Set<Class<?>> buildComponentClasses(Object bean) {
206            Set<Class<?>> classes = new HashSet<Class<?>>();
207            for (Class<?> i : bean.getClass().getInterfaces())
208                    classes.add(i);
209            
210            try {
211                    while (bean instanceof Advised)
212                            bean = ((Advised)bean).getTargetSource().getTarget();
213                    
214                    classes.add(AopUtils.getTargetClass(bean));
215            }
216            catch (Exception e) {
217                log.warn(e, "Could not get AOP class for component " + bean.getClass());
218                    return null;
219            }
220            
221            return classes;
222        }
223        
224        protected boolean isBeanAnnotationPresent(Object bean, Class<? extends Annotation> annotationClass) {
225            if (bean.getClass().isAnnotationPresent(annotationClass))
226                    return true;
227            
228            try {
229                    while (bean instanceof Advised)
230                            bean = ((Advised)bean).getTargetSource().getTarget();
231                    
232                    if (AopUtils.getTargetClass(bean).isAnnotationPresent(annotationClass))
233                            return true;
234            }
235            catch (Exception e) {
236                log.warn(e, "Could not get AOP class for component " + bean.getClass());
237            }
238            
239            return false;
240        }
241        
242        protected boolean isBeanMethodAnnotationPresent(Object bean, String methodName, Class<?>[] methodArgTypes, Class<? extends Annotation> annotationClass) {
243            try {
244                    Method m = bean.getClass().getMethod(methodName, methodArgTypes);
245                    if (m.isAnnotationPresent(annotationClass))
246                            return true;
247                    
248                    while (bean instanceof Advised)
249                            bean = ((Advised)bean).getTargetSource().getTarget();
250                    
251                    m = AopUtils.getTargetClass(bean).getMethod(methodName, methodArgTypes);
252                    if (m.isAnnotationPresent(annotationClass))
253                            return true;
254                    }
255            catch (Exception e) {
256                    log.warn("Could not find bean method", e);
257            }
258            
259            return false;
260        }
261    
262        
263        @Override
264        public void prepareCall(ServiceInvocationContext context, IInvocationCall c, String componentName, Class<?> componentClass) {
265            DataContext.init();
266            
267            DataUpdatePostprocessor dupp = (DataUpdatePostprocessor)findComponent(null, DataUpdatePostprocessor.class);
268            if (dupp != null)
269                    DataContext.get().setDataUpdatePostprocessor(dupp);
270        }
271    
272        @Override
273        public IInvocationResult postCall(ServiceInvocationContext context, Object result, String componentName, Class<?> componentClass) {
274                    List<ContextUpdate> results = null;
275            DataContext dataContext = DataContext.get();
276                    Object[][] updates = dataContext != null ? dataContext.getUpdates() : null;
277                    
278            InvocationResult ires = new InvocationResult(result, results);
279            if (isBeanAnnotationPresent(context.getBean(), BypassTideMerge.class))
280                    ires.setMerge(false);
281            else if (isBeanMethodAnnotationPresent(context.getBean(), context.getMethod().getName(), context.getMethod().getParameterTypes(), BypassTideMerge.class))
282                            ires.setMerge(false);
283            
284            ires.setUpdates(updates);
285            
286            return ires;
287        }
288    
289        @Override
290        public void postCallFault(ServiceInvocationContext context, Throwable t, String componentName, Class<?> componentClass) {        
291        }
292        
293        
294        public void setEntityManagerFactoryBeanName(String beanName) {
295            this.entityManagerFactoryBeanName = beanName;
296        }
297        
298        public void setPersistenceManagerBeanName(String beanName) {
299            this.persistenceManagerBeanName = beanName;
300        }
301        
302        /**
303         *  Create a TidePersistenceManager
304         *  
305         *  @param create create if not existent (can be false for use in entity merge)
306         *  @return a PersistenceContextManager
307         */
308        @Override
309        protected TidePersistenceManager getTidePersistenceManager(boolean create) {
310            if (!create)
311                return null;
312            
313            TidePersistenceManager pm = (TidePersistenceManager)GraniteContext.getCurrentInstance().getRequestMap().get(TidePersistenceManager.class.getName());
314            if (pm != null)
315                    return pm;
316            
317            pm = createPersistenceManager();
318            GraniteContext.getCurrentInstance().getRequestMap().put(TidePersistenceManager.class.getName(), pm);
319            return pm;
320        }
321        
322        private TidePersistenceManager createPersistenceManager() {
323            if (persistenceManagerBeanName == null) {
324                    if (entityManagerFactoryBeanName == null) {
325                            // No bean or entity manager factory specified 
326                            
327                            // 1. Look for a TidePersistenceManager bean
328                            Map<String, ?> pms = springContext.getBeansOfType(TidePersistenceManager.class);
329                            if (pms.size() > 1)
330                                    throw new RuntimeException("More than one Tide persistence managers defined");
331                            
332                            if (pms.size() == 1)
333                                    return (TidePersistenceManager)pms.values().iterator().next();
334                            
335                            // 2. If not found, try to determine the Spring transaction manager                     
336                            Map<String, ?> tms = springContext.getBeansOfType(PlatformTransactionManager.class);
337                            if (tms.isEmpty())
338                                    log.debug("No Spring transaction manager found, specify a persistence-manager-bean-name or entity-manager-factory-bean-name");
339                            else if (tms.size() > 1)
340                                    log.debug("More than one Spring transaction manager found, specify a persistence-manager-bean-name or entity-manager-factory-bean-name");
341                            else if (tms.size() == 1) {
342                                    PlatformTransactionManager ptm = (PlatformTransactionManager)tms.values().iterator().next();
343                                    
344                                    try {
345                                            // Check if a JPA factory is setup
346                                            // If we find one, define a persistence manager with the JPA factory and Spring transaction manager
347                                                    Class<?> emfiClass = TypeUtil.forName("org.springframework.orm.jpa.EntityManagerFactoryInfo");
348                                            Map<String, ?> emfs = springContext.getBeansOfType(emfiClass);
349                                                    if (emfs.size() == 1) {
350                                                            try {
351                                                                    Class<?> emfClass = TypeUtil.forName("javax.persistence.EntityManagerFactory");
352                                                        Class<?> pcmClass = TypeUtil.forName("org.granite.tide.data.JPAPersistenceManager");
353                                                        Constructor<?>[] cs = pcmClass.getConstructors();
354                                                    for (Constructor<?> c : cs) {
355                                                            if (c.getParameterTypes().length == 2 && emfClass.isAssignableFrom(c.getParameterTypes()[0])
356                                                                    && TideTransactionManager.class.isAssignableFrom(c.getParameterTypes()[1])) {
357                                                                    log.debug("Created JPA persistence manager with Spring transaction manager");
358                                                                            TideTransactionManager tm = new SpringTransactionManager(ptm);
359                                                                    return (TidePersistenceManager)c.newInstance(((EntityManagerFactoryInfo)emfs.values().iterator().next()).getNativeEntityManagerFactory(), tm);
360                                                            }
361                                                    }
362                                                            }
363                                                            catch (Exception e) {
364                                                                    log.error(e, "Could not setup persistence manager for JPA " + emfs.keySet().iterator().next());
365                                                            }
366                                                    }
367                                    }
368                                            catch (ClassNotFoundException e) {
369                                                    // Ignore: JPA not present on classpath
370                                            }
371                                            catch (NoClassDefFoundError e) {
372                                                    // Ignore: JPA not present on classpath
373                                            }
374                                            catch (Exception e) {
375                                                    log.error("Could not lookup EntityManagerFactoryInfo", e);
376                                            }
377                                            
378                                    // If no entity manager, we define a Spring persistence manager 
379                                            // that will try to infer persistence info from the Spring transaction manager
380                                            return new SpringPersistenceManager(ptm);
381                            }
382                    }
383                    
384                String emfBeanName = entityManagerFactoryBeanName != null ? entityManagerFactoryBeanName : "entityManagerFactory";
385                try {
386                    // Lookup the specified entity manager factory
387                    Object emf = findComponent(emfBeanName, null);
388                    
389                    // Try to determine the Spring transaction manager
390                    TideTransactionManager tm = null;
391                            Map<String, ?> ptms = springContext.getBeansOfType(PlatformTransactionManager.class);
392                            if (ptms.size() == 1) {
393                                    log.debug("Found Spring transaction manager " + ptms.keySet().iterator().next());
394                                    tm = new SpringTransactionManager((PlatformTransactionManager)ptms.values().iterator().next());
395                            }
396                    
397                                    Class<?> emfClass = TypeUtil.forName("javax.persistence.EntityManagerFactory");
398                        Class<?> pcmClass = TypeUtil.forName("org.granite.tide.data.JPAPersistenceManager");
399                        Constructor<?>[] cs = pcmClass.getConstructors();
400                        if (tm != null) {
401                            for (Constructor<?> c : cs) {
402                                    if (c.getParameterTypes().length == 2 && emfClass.isAssignableFrom(c.getParameterTypes()[0])
403                                            && TideTransactionManager.class.isAssignableFrom(c.getParameterTypes()[1])) {
404                                            log.debug("Created JPA persistence manager with Spring transaction manager");
405                                            return (TidePersistenceManager)c.newInstance(((EntityManagerFactoryInfo)emf).getNativeEntityManagerFactory(), tm);
406                                    }
407                            }
408                        }
409                        else {
410                                for (Constructor<?> c : cs) {
411                                    if (c.getParameterTypes().length == 1 && emfClass.isAssignableFrom(c.getParameterTypes()[0])) {
412                                            log.debug("Created default JPA persistence manager");
413                                            return (TidePersistenceManager)c.newInstance(emf);
414                                    }
415                                }
416                        }
417                        
418                        throw new RuntimeException("Default Tide persistence manager not found");
419                }
420                catch (ServiceException e) {
421                    if (entityManagerFactoryBeanName != null)
422                            log.debug("EntityManagerFactory named %s not found, JPA support disabled", emfBeanName);
423                    
424                    return null;
425                }
426                catch (Exception e) {
427                    throw new RuntimeException("Could not create default Tide persistence manager", e);
428                }
429            }
430            
431            return (TidePersistenceManager)findComponent(persistenceManagerBeanName, null);
432        }
433    }