001    /*
002      GRANITE DATA SERVICES
003      Copyright (C) 2011 GRANITE DATA SERVICES S.A.S.
004    
005      This file is part of Granite Data Services.
006    
007      Granite Data Services is free software; you can redistribute it and/or modify
008      it under the terms of the GNU Library General Public License as published by
009      the Free Software Foundation; either version 2 of the License, or (at your
010      option) any later version.
011    
012      Granite Data Services is distributed in the hope that it will be useful, but
013      WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
014      FITNESS FOR A PARTICULAR PURPOSE. See the GNU Library General Public License
015      for more details.
016    
017      You should have received a copy of the GNU Library General Public License
018      along with this library; if not, see <http://www.gnu.org/licenses/>.
019    */
020    
021    package org.granite.tide.spring;
022    
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;

import javax.servlet.ServletContext;

import org.granite.context.GraniteContext;
import org.granite.logging.Logger;
import org.granite.messaging.service.ServiceException;
import org.granite.messaging.service.ServiceInvocationContext;
import org.granite.messaging.webapp.HttpGraniteContext;
import org.granite.tide.IInvocationCall;
import org.granite.tide.IInvocationResult;
import org.granite.tide.TidePersistenceManager;
import org.granite.tide.TideServiceContext;
import org.granite.tide.TideTransactionManager;
import org.granite.tide.annotations.BypassTideMerge;
import org.granite.tide.async.AsyncPublisher;
import org.granite.tide.data.DataContext;
import org.granite.tide.data.DataUpdatePostprocessor;
import org.granite.tide.invocation.ContextUpdate;
import org.granite.tide.invocation.InvocationResult;
import org.granite.util.ClassUtil;
import org.springframework.aop.framework.Advised;
import org.springframework.aop.support.AopUtils;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.NoSuchBeanDefinitionException;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.orm.jpa.EntityManagerFactoryInfo;
import org.springframework.transaction.PlatformTransactionManager;
import org.springframework.web.context.support.WebApplicationContextUtils;
059    
060    
061    /**
062     *  @author Sebastien Deleuze
063     *      @author William DRAI
064     */
065    public class SpringServiceContext extends TideServiceContext implements ApplicationContextAware {
066    
067        private static final long serialVersionUID = 1L;
068        
069        protected transient ApplicationContext springContext = null;
070        
071        private String persistenceManagerBeanName = null;
072        private String entityManagerFactoryBeanName = null;
073        
074        private static final Logger log = Logger.getLogger(SpringServiceContext.class);
075                    
076        public SpringServiceContext() throws ServiceException {
077            super();
078            
079            log.debug("Getting spring context from container");
080            getSpringContext();
081        }
082        
083        public SpringServiceContext(ApplicationContext springContext) throws ServiceException {
084            super();
085            
086            this.springContext = springContext;
087        }
088        
089        public void setApplicationContext(ApplicationContext springContext) {
090            this.springContext = springContext;
091        }
092    
093        protected ApplicationContext getSpringContext() {
094            if (springContext == null) {
095                GraniteContext context = GraniteContext.getCurrentInstance();
096                ServletContext sc = ((HttpGraniteContext)context).getServletContext();
097                springContext = WebApplicationContextUtils.getRequiredWebApplicationContext(sc);
098            }
099            return springContext;           
100        }
101        
102        
103        @Override
104        protected AsyncPublisher getAsyncPublisher() {
105            return null;
106        }    
107        
108        @Override
109        public Object findComponent(String componentName, Class<?> componentClass) {
110            Object bean = null;
111            String key = COMPONENT_ATTR + (componentName != null ? componentName : "_CLASS_" + componentClass.getName());
112            
113            GraniteContext context = GraniteContext.getCurrentInstance();
114            if (context != null) {
115                    bean = context.getRequestMap().get(key);
116                    if (bean != null)
117                            return bean;
118            }
119            
120            ApplicationContext springContext = getSpringContext();
121            try {
122                    if (componentClass != null) {
123                            Map<String, ?> beans = springContext.getBeansOfType(componentClass);
124                            if (beans.size() == 1)
125                                    bean = beans.values().iterator().next();
126                            else if (beans.size() > 1 && componentName != null && !("".equals(componentName))) {
127                                    if (beans.containsKey(componentName))
128                                            bean = beans.get(componentName);
129                            }
130                            else if (beans.isEmpty() && springContext.getClass().getName().indexOf("Grails") > 0 && componentClass.getName().endsWith("Service")) {
131                                    try {
132                                            Object serviceClass = springContext.getBean(componentClass.getName() + "ServiceClass");                         
133                                            Method m = serviceClass.getClass().getMethod("getPropertyName");
134                                            String compName = (String)m.invoke(serviceClass);
135                                            bean = springContext.getBean(compName);
136                                    }
137                                    catch (NoSuchMethodException e) {
138                                            log.error(e, "Could not get service class for %s", componentClass.getName());
139                                    }
140                                    catch (InvocationTargetException e) {
141                                            log.error(e.getCause(), "Could not get service class for %s", componentClass.getName());
142                                    }
143                                    catch (IllegalAccessException e) {
144                                            log.error(e.getCause(), "Could not get service class for %s", componentClass.getName());
145                                    }
146                            }
147                    }
148                    if (bean == null && componentName != null && !("".equals(componentName)))
149                            bean = springContext.getBean(componentName);
150                    
151                if (context != null)
152                    context.getRequestMap().put(key, bean);
153                return bean;
154            }
155            catch (NoSuchBeanDefinitionException nexc) {
156                    if (componentName != null && componentName.endsWith("Controller")) {
157                            try {
158                                    int idx = componentName.lastIndexOf(".");
159                                    String controllerName = idx > 0 
160                                            ? componentName.substring(0, idx+1) + componentName.substring(idx+1, idx+2).toUpperCase() + componentName.substring(idx+2)
161                                            : componentName.substring(0, 1).toUpperCase() + componentName.substring(1);
162                                    bean = getSpringContext().getBean(controllerName);
163                        if (context != null)
164                            context.getRequestMap().put(key, bean);
165                                    return bean;
166                            }
167                    catch (NoSuchBeanDefinitionException nexc2) {
168                    }
169                    }
170                    
171                String msg = "Spring service named '" + componentName + "' does not exist.";
172                ServiceException e = new ServiceException(msg, nexc);
173                throw e;
174            } 
175            catch (BeansException bexc) {
176                String msg = "Unable to create Spring service named '" + componentName + "'";
177                ServiceException e = new ServiceException(msg, bexc);
178                throw e;
179            }    
180        }
181        
182        @Override
183        @SuppressWarnings("unchecked")
184        public Set<Class<?>> findComponentClasses(String componentName, Class<?> componentClass) {
185            String key = COMPONENT_CLASS_ATTR + componentName;
186            Set<Class<?>> classes = null; 
187            GraniteContext context = GraniteContext.getCurrentInstance();
188            if (context != null) {
189                    classes = (Set<Class<?>>)context.getRequestMap().get(key);
190                    if (classes != null)
191                            return classes;
192            }
193            
194            try {
195                Object bean = findComponent(componentName, componentClass);
196                classes = new HashSet<Class<?>>();
197                for (Class<?> i : bean.getClass().getInterfaces())
198                    classes.add(i);
199                
200                while (bean instanceof Advised)
201                    bean = ((Advised)bean).getTargetSource().getTarget();
202                
203                classes.add(AopUtils.getTargetClass(bean));
204                
205                if (context != null)
206                    context.getRequestMap().put(key, classes);
207                return classes;
208            }
209            catch (Exception e) {
210                log.warn(e, "Could not get class for component " + componentName);
211                return null;
212            }
213        }
214    
215        
216        @Override
217        public void prepareCall(ServiceInvocationContext context, IInvocationCall c, String componentName, Class<?> componentClass) {
218            DataContext.init();
219            
220            DataUpdatePostprocessor dupp = (DataUpdatePostprocessor)findComponent(null, DataUpdatePostprocessor.class);
221            if (dupp != null)
222                    DataContext.get().setDataUpdatePostprocessor(dupp);
223        }
224    
225        @Override
226        public IInvocationResult postCall(ServiceInvocationContext context, Object result, String componentName, Class<?> componentClass) {
227                    List<ContextUpdate> results = null;
228            DataContext dataContext = DataContext.get();
229                    Object[][] updates = dataContext != null ? dataContext.getUpdates() : null;
230                    
231            InvocationResult ires = new InvocationResult(result, results);
232            if (context.getBean().getClass().isAnnotationPresent(BypassTideMerge.class))
233                    ires.setMerge(false);
234            else if (context.getMethod().isAnnotationPresent(BypassTideMerge.class))
235                            ires.setMerge(false);
236            
237            ires.setUpdates(updates);
238            
239            return ires;
240        }
241    
242        @Override
243        public void postCallFault(ServiceInvocationContext context, Throwable t, String componentName, Class<?> componentClass) {        
244        }
245        
246        
247        public void setEntityManagerFactoryBeanName(String beanName) {
248            this.entityManagerFactoryBeanName = beanName;
249        }
250        
251        public void setPersistenceManagerBeanName(String beanName) {
252            this.persistenceManagerBeanName = beanName;
253        }
254        
255        /**
256         *  Create a TidePersistenceManager
257         *  
258         *  @param create create if not existent (can be false for use in entity merge)
259         *  @return a PersistenceContextManager
260         */
261        @Override
262        protected TidePersistenceManager getTidePersistenceManager(boolean create) {
263            if (!create)
264                return null;
265            
266            TidePersistenceManager pm = (TidePersistenceManager)GraniteContext.getCurrentInstance().getRequestMap().get(TidePersistenceManager.class.getName());
267            if (pm != null)
268                    return pm;
269            
270            pm = createPersistenceManager();
271            GraniteContext.getCurrentInstance().getRequestMap().put(TidePersistenceManager.class.getName(), pm);
272            return pm;
273        }
274        
275        private TidePersistenceManager createPersistenceManager() {
276            if (persistenceManagerBeanName == null) {
277                    if (entityManagerFactoryBeanName == null) {
278                            // No bean or entity manager factory specified 
279                            
280                            // 1. Look for a TidePersistenceManager bean
281                            Map<String, ?> pms = springContext.getBeansOfType(TidePersistenceManager.class);
282                            if (pms.size() > 1)
283                                    throw new RuntimeException("More than one Tide persistence managers defined");
284                            
285                            if (pms.size() == 1)
286                                    return (TidePersistenceManager)pms.values().iterator().next();
287                            
288                            // 2. If not found, try to determine the Spring transaction manager                     
289                            Map<String, ?> tms = springContext.getBeansOfType(PlatformTransactionManager.class);
290                            if (tms.isEmpty())
291                                    log.debug("No Spring transaction manager found, specify a persistence-manager-bean-name or entity-manager-factory-bean-name");
292                            else if (tms.size() > 1)
293                                    log.debug("More than one Spring transaction manager found, specify a persistence-manager-bean-name or entity-manager-factory-bean-name");
294                            else if (tms.size() == 1) {
295                                    PlatformTransactionManager ptm = (PlatformTransactionManager)tms.values().iterator().next();
296                                    
297                                    try {
298                                            // Check if a JPA factory is setup
299                                            // If we find one, define a persistence manager with the JPA factory and Spring transaction manager
300                                                    Class<?> emfiClass = ClassUtil.forName("org.springframework.orm.jpa.EntityManagerFactoryInfo");
301                                            Map<String, ?> emfs = springContext.getBeansOfType(emfiClass);
302                                                    if (emfs.size() == 1) {
303                                                            try {
304                                                                    Class<?> emfClass = ClassUtil.forName("javax.persistence.EntityManagerFactory");
305                                                        Class<?> pcmClass = ClassUtil.forName("org.granite.tide.data.JPAPersistenceManager");
306                                                        Constructor<?>[] cs = pcmClass.getConstructors();
307                                                    for (Constructor<?> c : cs) {
308                                                            if (c.getParameterTypes().length == 2 && emfClass.isAssignableFrom(c.getParameterTypes()[0])
309                                                                    && TideTransactionManager.class.isAssignableFrom(c.getParameterTypes()[1])) {
310                                                                    log.debug("Created JPA persistence manager with Spring transaction manager");
311                                                                            TideTransactionManager tm = new SpringTransactionManager(ptm);
312                                                                    return (TidePersistenceManager)c.newInstance(((EntityManagerFactoryInfo)emfs.values().iterator().next()).getNativeEntityManagerFactory(), tm);
313                                                            }
314                                                    }
315                                                            }
316                                                            catch (Exception e) {
317                                                                    log.error(e, "Could not setup persistence manager for JPA " + emfs.keySet().iterator().next());
318                                                            }
319                                                    }
320                                    }
321                                            catch (ClassNotFoundException e) {
322                                                    // Ignore: JPA not present on classpath
323                                            }
324                                            catch (NoClassDefFoundError e) {
325                                                    // Ignore: JPA not present on classpath
326                                            }
327                                            catch (Exception e) {
328                                                    log.error("Could not lookup EntityManagerFactoryInfo", e);
329                                            }
330                                            
331                                    // If no entity manager, we define a Spring persistence manager 
332                                            // that will try to infer persistence info from the Spring transaction manager
333                                            return new SpringPersistenceManager(ptm);
334                            }
335                    }
336                    
337                String emfBeanName = entityManagerFactoryBeanName != null ? entityManagerFactoryBeanName : "entityManagerFactory";
338                try {
339                    // Lookup the specified entity manager factory
340                    Object emf = findComponent(emfBeanName, null);
341                    
342                    // Try to determine the Spring transaction manager
343                    TideTransactionManager tm = null;
344                            Map<String, ?> ptms = springContext.getBeansOfType(PlatformTransactionManager.class);
345                            if (ptms.size() == 1) {
346                                    log.debug("Found Spring transaction manager " + ptms.keySet().iterator().next());
347                                    tm = new SpringTransactionManager((PlatformTransactionManager)ptms.values().iterator().next());
348                            }
349                    
350                                    Class<?> emfClass = ClassUtil.forName("javax.persistence.EntityManagerFactory");
351                        Class<?> pcmClass = ClassUtil.forName("org.granite.tide.data.JPAPersistenceManager");
352                        Constructor<?>[] cs = pcmClass.getConstructors();
353                        if (tm != null) {
354                            for (Constructor<?> c : cs) {
355                                    if (c.getParameterTypes().length == 2 && emfClass.isAssignableFrom(c.getParameterTypes()[0])
356                                            && TideTransactionManager.class.isAssignableFrom(c.getParameterTypes()[1])) {
357                                            log.debug("Created JPA persistence manager with Spring transaction manager");
358                                            return (TidePersistenceManager)c.newInstance(((EntityManagerFactoryInfo)emf).getNativeEntityManagerFactory(), tm);
359                                    }
360                            }
361                        }
362                        else {
363                                for (Constructor<?> c : cs) {
364                                    if (c.getParameterTypes().length == 1 && emfClass.isAssignableFrom(c.getParameterTypes()[0])) {
365                                            log.debug("Created default JPA persistence manager");
366                                            return (TidePersistenceManager)c.newInstance(emf);
367                                    }
368                                }
369                        }
370                        
371                        throw new RuntimeException("Default Tide persistence manager not found");
372                }
373                catch (ServiceException e) {
374                    if (entityManagerFactoryBeanName != null)
375                            log.debug("EntityManagerFactory named %s not found, JPA support disabled", emfBeanName);
376                    
377                    return null;
378                }
379                catch (Exception e) {
380                    throw new RuntimeException("Could not create default Tide persistence manager", e);
381                }
382            }
383            
384            return (TidePersistenceManager)findComponent(persistenceManagerBeanName, null);
385        }
386    }