001/*
002  GRANITE DATA SERVICES
003  Copyright (C) 2011 GRANITE DATA SERVICES S.A.S.
004
005  This file is part of Granite Data Services.
006
007  Granite Data Services is free software; you can redistribute it and/or modify
008  it under the terms of the GNU Library General Public License as published by
009  the Free Software Foundation; either version 2 of the License, or (at your
010  option) any later version.
011
012  Granite Data Services is distributed in the hope that it will be useful, but
013  WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
014  FITNESS FOR A PARTICULAR PURPOSE. See the GNU Library General Public License
015  for more details.
016
017  You should have received a copy of the GNU Library General Public License
018  along with this library; if not, see <http://www.gnu.org/licenses/>.
019*/
020
021package org.granite.tide.spring;
022
023import java.lang.annotation.Annotation;
024import java.lang.reflect.Constructor;
025import java.lang.reflect.InvocationTargetException;
026import java.lang.reflect.Method;
027import java.util.HashSet;
028import java.util.List;
029import java.util.Map;
030import java.util.Set;
031
032import javax.servlet.ServletContext;
033
034import org.granite.context.GraniteContext;
035import org.granite.logging.Logger;
036import org.granite.messaging.service.ServiceException;
037import org.granite.messaging.service.ServiceInvocationContext;
038import org.granite.messaging.webapp.HttpGraniteContext;
039import org.granite.tide.IInvocationCall;
040import org.granite.tide.IInvocationResult;
041import org.granite.tide.TidePersistenceManager;
042import org.granite.tide.TideServiceContext;
043import org.granite.tide.TideTransactionManager;
044import org.granite.tide.annotations.BypassTideMerge;
045import org.granite.tide.async.AsyncPublisher;
046import org.granite.tide.data.DataContext;
047import org.granite.tide.data.DataUpdatePostprocessor;
048import org.granite.tide.invocation.ContextUpdate;
049import org.granite.tide.invocation.InvocationResult;
050import org.granite.util.TypeUtil;
051import org.springframework.aop.framework.Advised;
052import org.springframework.aop.support.AopUtils;
053import org.springframework.beans.BeansException;
054import org.springframework.beans.factory.NoSuchBeanDefinitionException;
055import org.springframework.context.ApplicationContext;
056import org.springframework.context.ApplicationContextAware;
057import org.springframework.orm.jpa.EntityManagerFactoryInfo;
058import org.springframework.transaction.PlatformTransactionManager;
059import org.springframework.web.context.support.WebApplicationContextUtils;
060
061
062/**
063 *  @author Sebastien Deleuze
064 *      @author William DRAI
065 */
066public class SpringServiceContext extends TideServiceContext implements ApplicationContextAware {
067
068    private static final long serialVersionUID = 1L;
069    
070    protected transient ApplicationContext springContext = null;
071    
072    private String persistenceManagerBeanName = null;
073    private String entityManagerFactoryBeanName = null;
074    
075    private static final Logger log = Logger.getLogger(SpringServiceContext.class);
076                
077    public SpringServiceContext() throws ServiceException {
078        super();
079        
080        log.debug("Getting spring context from container");
081        getSpringContext();
082    }
083    
084    public SpringServiceContext(ApplicationContext springContext) throws ServiceException {
085        super();
086        
087        this.springContext = springContext;
088    }
089    
090    public void setApplicationContext(ApplicationContext springContext) {
091        this.springContext = springContext;
092    }
093
094    protected ApplicationContext getSpringContext() {
095        if (springContext == null) {
096            GraniteContext context = GraniteContext.getCurrentInstance();
097            ServletContext sc = ((HttpGraniteContext)context).getServletContext();
098            springContext = WebApplicationContextUtils.getRequiredWebApplicationContext(sc);
099        }
100        return springContext;           
101    }
102    
103    
104    @Override
105    protected AsyncPublisher getAsyncPublisher() {
106        return null;
107    }    
108    
109    @Override
110    public Object findComponent(String componentName, Class<?> componentClass) {
111        Object bean = null;
112        String key = COMPONENT_ATTR + (componentName != null ? componentName : "_CLASS_" + componentClass.getName());
113        
114        GraniteContext context = GraniteContext.getCurrentInstance();
115        if (context != null) {
116                bean = context.getRequestMap().get(key);
117                if (bean != null)
118                        return bean;
119        }
120        
121        ApplicationContext springContext = getSpringContext();
122        try {
123                if (componentClass != null) {
124                        Map<String, ?> beans = springContext.getBeansOfType(componentClass);
125                        if (beans.size() == 1)
126                                bean = beans.values().iterator().next();
127                        else if (beans.size() > 1 && componentName != null && !("".equals(componentName))) {
128                                if (beans.containsKey(componentName))
129                                        bean = beans.get(componentName);
130                        }
131                        else if (beans.isEmpty() && springContext.getClass().getName().indexOf("Grails") > 0 && componentClass.getName().endsWith("Service")) {
132                                try {
133                                        Object serviceClass = springContext.getBean(componentClass.getName() + "ServiceClass");                         
134                                        Method m = serviceClass.getClass().getMethod("getPropertyName");
135                                        String compName = (String)m.invoke(serviceClass);
136                                        bean = springContext.getBean(compName);
137                                }
138                                catch (NoSuchMethodException e) {
139                                        log.error(e, "Could not get service class for %s", componentClass.getName());
140                                }
141                                catch (InvocationTargetException e) {
142                                        log.error(e.getCause(), "Could not get service class for %s", componentClass.getName());
143                                }
144                                catch (IllegalAccessException e) {
145                                        log.error(e.getCause(), "Could not get service class for %s", componentClass.getName());
146                                }
147                        }
148                }
149                if (bean == null && componentName != null && !("".equals(componentName)))
150                        bean = springContext.getBean(componentName);
151                
152            if (context != null)
153                context.getRequestMap().put(key, bean);
154            return bean;
155        }
156        catch (NoSuchBeanDefinitionException nexc) {
157                if (componentName != null && componentName.endsWith("Controller")) {
158                        try {
159                                int idx = componentName.lastIndexOf(".");
160                                String controllerName = idx > 0 
161                                        ? componentName.substring(0, idx+1) + componentName.substring(idx+1, idx+2).toUpperCase() + componentName.substring(idx+2)
162                                        : componentName.substring(0, 1).toUpperCase() + componentName.substring(1);
163                                bean = getSpringContext().getBean(controllerName);
164                    if (context != null)
165                        context.getRequestMap().put(key, bean);
166                                return bean;
167                        }
168                catch (NoSuchBeanDefinitionException nexc2) {
169                }
170                }
171                
172            String msg = "Spring service named '" + componentName + "' does not exist.";
173            ServiceException e = new ServiceException(msg, nexc);
174            throw e;
175        } 
176        catch (BeansException bexc) {
177            String msg = "Unable to create Spring service named '" + componentName + "'";
178            ServiceException e = new ServiceException(msg, bexc);
179            throw e;
180        }    
181    }
182    
183    @Override
184    @SuppressWarnings("unchecked")
185    public Set<Class<?>> findComponentClasses(String componentName, Class<?> componentClass) {
186        String key = COMPONENT_CLASS_ATTR + componentName;
187        Set<Class<?>> classes = null; 
188        GraniteContext context = GraniteContext.getCurrentInstance();
189        if (context != null) {
190                classes = (Set<Class<?>>)context.getRequestMap().get(key);
191                if (classes != null)
192                        return classes;
193        }
194        
195        Object bean = findComponent(componentName, componentClass);
196        classes = buildComponentClasses(bean);        
197        if (classes == null)
198                return null;
199        
200        if (context != null)
201                context.getRequestMap().put(key, classes);
202        return classes;
203    }
204    
205    protected Set<Class<?>> buildComponentClasses(Object bean) {
206        Set<Class<?>> classes = new HashSet<Class<?>>();
207        for (Class<?> i : bean.getClass().getInterfaces())
208                classes.add(i);
209        
210        try {
211                while (bean instanceof Advised)
212                        bean = ((Advised)bean).getTargetSource().getTarget();
213                
214                classes.add(AopUtils.getTargetClass(bean));
215        }
216        catch (Exception e) {
217            log.warn(e, "Could not get AOP class for component " + bean.getClass());
218                return null;
219        }
220        
221        return classes;
222    }
223    
224    protected boolean isBeanAnnotationPresent(Object bean, Class<? extends Annotation> annotationClass) {
225        if (bean.getClass().isAnnotationPresent(annotationClass))
226                return true;
227        
228        try {
229                while (bean instanceof Advised)
230                        bean = ((Advised)bean).getTargetSource().getTarget();
231                
232                if (AopUtils.getTargetClass(bean).isAnnotationPresent(annotationClass))
233                        return true;
234        }
235        catch (Exception e) {
236            log.warn(e, "Could not get AOP class for component " + bean.getClass());
237        }
238        
239        return false;
240    }
241    
242    protected boolean isBeanMethodAnnotationPresent(Object bean, String methodName, Class<?>[] methodArgTypes, Class<? extends Annotation> annotationClass) {
243        try {
244                Method m = bean.getClass().getMethod(methodName, methodArgTypes);
245                if (m.isAnnotationPresent(annotationClass))
246                        return true;
247                
248                while (bean instanceof Advised)
249                        bean = ((Advised)bean).getTargetSource().getTarget();
250                
251                m = AopUtils.getTargetClass(bean).getMethod(methodName, methodArgTypes);
252                if (m.isAnnotationPresent(annotationClass))
253                        return true;
254                }
255        catch (Exception e) {
256                log.warn("Could not find bean method", e);
257        }
258        
259        return false;
260    }
261
262    
263    @Override
264    public void prepareCall(ServiceInvocationContext context, IInvocationCall c, String componentName, Class<?> componentClass) {
265        DataContext.init();
266        
267        DataUpdatePostprocessor dupp = (DataUpdatePostprocessor)findComponent(null, DataUpdatePostprocessor.class);
268        if (dupp != null)
269                DataContext.get().setDataUpdatePostprocessor(dupp);
270    }
271
272    @Override
273    public IInvocationResult postCall(ServiceInvocationContext context, Object result, String componentName, Class<?> componentClass) {
274                List<ContextUpdate> results = null;
275        DataContext dataContext = DataContext.get();
276                Object[][] updates = dataContext != null ? dataContext.getUpdates() : null;
277                
278        InvocationResult ires = new InvocationResult(result, results);
279        if (isBeanAnnotationPresent(context.getBean(), BypassTideMerge.class))
280                ires.setMerge(false);
281        else if (isBeanMethodAnnotationPresent(context.getBean(), context.getMethod().getName(), context.getMethod().getParameterTypes(), BypassTideMerge.class))
282                        ires.setMerge(false);
283        
284        ires.setUpdates(updates);
285        
286        return ires;
287    }
288
289    @Override
290    public void postCallFault(ServiceInvocationContext context, Throwable t, String componentName, Class<?> componentClass) {        
291    }
292    
293    
294    public void setEntityManagerFactoryBeanName(String beanName) {
295        this.entityManagerFactoryBeanName = beanName;
296    }
297    
298    public void setPersistenceManagerBeanName(String beanName) {
299        this.persistenceManagerBeanName = beanName;
300    }
301    
302    /**
303     *  Create a TidePersistenceManager
304     *  
305     *  @param create create if not existent (can be false for use in entity merge)
306     *  @return a PersistenceContextManager
307     */
308    @Override
309    protected TidePersistenceManager getTidePersistenceManager(boolean create) {
310        if (!create)
311            return null;
312        
313        TidePersistenceManager pm = (TidePersistenceManager)GraniteContext.getCurrentInstance().getRequestMap().get(TidePersistenceManager.class.getName());
314        if (pm != null)
315                return pm;
316        
317        pm = createPersistenceManager();
318        GraniteContext.getCurrentInstance().getRequestMap().put(TidePersistenceManager.class.getName(), pm);
319        return pm;
320    }
321    
322    private TidePersistenceManager createPersistenceManager() {
323        if (persistenceManagerBeanName == null) {
324                if (entityManagerFactoryBeanName == null) {
325                        // No bean or entity manager factory specified 
326                        
327                        // 1. Look for a TidePersistenceManager bean
328                        Map<String, ?> pms = springContext.getBeansOfType(TidePersistenceManager.class);
329                        if (pms.size() > 1)
330                                throw new RuntimeException("More than one Tide persistence managers defined");
331                        
332                        if (pms.size() == 1)
333                                return (TidePersistenceManager)pms.values().iterator().next();
334                        
335                        // 2. If not found, try to determine the Spring transaction manager                     
336                        Map<String, ?> tms = springContext.getBeansOfType(PlatformTransactionManager.class);
337                        if (tms.isEmpty())
338                                log.debug("No Spring transaction manager found, specify a persistence-manager-bean-name or entity-manager-factory-bean-name");
339                        else if (tms.size() > 1)
340                                log.debug("More than one Spring transaction manager found, specify a persistence-manager-bean-name or entity-manager-factory-bean-name");
341                        else if (tms.size() == 1) {
342                                PlatformTransactionManager ptm = (PlatformTransactionManager)tms.values().iterator().next();
343                                        
344                                // If no entity manager, we define a Spring persistence manager 
345                                        // that will try to infer persistence info from the Spring transaction manager
346                                        return new SpringPersistenceManager(ptm);
347                        }
348                }
349                
350            String emfBeanName = entityManagerFactoryBeanName != null ? entityManagerFactoryBeanName : "entityManagerFactory";
351            try {
352                // Lookup the specified entity manager factory
353                Object emf = findComponent(emfBeanName, null);
354                
355                // Try to determine the Spring transaction manager
356                TideTransactionManager tm = null;
357                        Map<String, ?> ptms = springContext.getBeansOfType(PlatformTransactionManager.class);
358                        if (ptms.size() == 1) {
359                                log.debug("Found Spring transaction manager " + ptms.keySet().iterator().next());
360                                tm = new SpringTransactionManager((PlatformTransactionManager)ptms.values().iterator().next());
361                        }
362                
363                                Class<?> emfClass = TypeUtil.forName("javax.persistence.EntityManagerFactory");
364                    Class<?> pcmClass = TypeUtil.forName("org.granite.tide.data.JPAPersistenceManager");
365                    Constructor<?>[] cs = pcmClass.getConstructors();
366                    if (tm != null) {
367                        for (Constructor<?> c : cs) {
368                                if (c.getParameterTypes().length == 2 && emfClass.isAssignableFrom(c.getParameterTypes()[0])
369                                        && TideTransactionManager.class.isAssignableFrom(c.getParameterTypes()[1])) {
370                                        log.debug("Created JPA persistence manager with Spring transaction manager");
371                                        return (TidePersistenceManager)c.newInstance(((EntityManagerFactoryInfo)emf).getNativeEntityManagerFactory(), tm);
372                                }
373                        }
374                    }
375                    else {
376                            for (Constructor<?> c : cs) {
377                                if (c.getParameterTypes().length == 1 && emfClass.isAssignableFrom(c.getParameterTypes()[0])) {
378                                        log.debug("Created default JPA persistence manager");
379                                        return (TidePersistenceManager)c.newInstance(emf);
380                                }
381                            }
382                    }
383                    
384                    throw new RuntimeException("Default Tide persistence manager not found");
385            }
386            catch (ServiceException e) {
387                if (entityManagerFactoryBeanName != null)
388                        log.debug("EntityManagerFactory named %s not found, JPA support disabled", emfBeanName);
389                
390                return null;
391            }
392            catch (Exception e) {
393                throw new RuntimeException("Could not create default Tide persistence manager", e);
394            }
395        }
396        
397        return (TidePersistenceManager)findComponent(persistenceManagerBeanName, null);
398    }
399}