001    /*
002      GRANITE DATA SERVICES
003      Copyright (C) 2011 GRANITE DATA SERVICES S.A.S.
004    
005      This file is part of Granite Data Services.
006    
007      Granite Data Services is free software; you can redistribute it and/or modify
008      it under the terms of the GNU Library General Public License as published by
009      the Free Software Foundation; either version 2 of the License, or (at your
010      option) any later version.
011    
012      Granite Data Services is distributed in the hope that it will be useful, but
013      WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
014      FITNESS FOR A PARTICULAR PURPOSE. See the GNU Library General Public License
015      for more details.
016    
017      You should have received a copy of the GNU Library General Public License
018      along with this library; if not, see <http://www.gnu.org/licenses/>.
019    */
020    
021    package org.granite.cdi;
022    
import java.util.Iterator;
import java.util.Map;

import javax.enterprise.context.Conversation;
import javax.enterprise.inject.spi.Bean;
import javax.enterprise.inject.spi.BeanManager;
import javax.naming.InitialContext;
import javax.naming.NameNotFoundException;
import javax.naming.NamingException;
import javax.servlet.ServletRequestEvent;

import org.granite.context.GraniteContext;
import org.granite.logging.Logger;
import org.granite.messaging.amf.process.AMF3MessageInterceptor;
import org.granite.messaging.service.ServiceException;
import org.granite.messaging.webapp.HttpGraniteContext;
import org.granite.messaging.webapp.HttpServletRequestParamWrapper;
import org.granite.tide.cdi.ConversationState;
import org.granite.tide.cdi.EventState;
import org.granite.tide.cdi.SessionState;
import org.granite.util.ClassUtil;
import org.jboss.weld.servlet.WeldListener;

import flex.messaging.messages.Message;
046    
047    
048    public class CDIInterceptor implements AMF3MessageInterceptor {
049            
050            private static final Logger log = Logger.getLogger(CDIInterceptor.class);
051    
052        private static final String CONVERSATION_ID = "conversationId";
053        private static final String IS_LONG_RUNNING_CONVERSATION = "isLongRunningConversation";
054        private static final String WAS_LONG_RUNNING_CONVERSATION_CREATED = "wasLongRunningConversationCreated";
055        private static final String WAS_LONG_RUNNING_CONVERSATION_ENDED = "wasLongRunningConversationEnded";
056            
057        private CDIConversationManager conversationManager;
058        
059        
060        public CDIInterceptor() {
061            try {
062                    Thread.currentThread().getContextClassLoader().loadClass("org.jboss.weld.context.http.HttpConversationContext");
063                    conversationManager = ClassUtil.newInstance("org.granite.cdi.Weld11ConversationManager", CDIConversationManager.class);
064                    log.info("Detected Weld 1.1");
065            }
066            catch (Exception e) {
067                    try {
068                            conversationManager = ClassUtil.newInstance("org.granite.cdi.Weld10ConversationManager", CDIConversationManager.class);
069                            log.info("Detected Weld 1.0");
070                    }
071                    catch (Exception f) {
072                            throw new RuntimeException("Could not load conversation manager for CDI implementation", f);
073                    }
074            }
075        }
076        
077        
078        public static BeanManager lookupBeanManager() {
079                    HttpGraniteContext context = (HttpGraniteContext)GraniteContext.getCurrentInstance();
080                    BeanManager manager = (BeanManager)context.getServletContext().getAttribute(BeanManager.class.getName());
081                    if (manager != null)
082                            return manager;         
083                    manager = (BeanManager)context.getServletContext().getAttribute("org.jboss.weld.environment.servlet.javax.enterprise.inject.spi.BeanManager");
084                    if (manager != null)
085                            return manager;
086                    
087                    InitialContext ic = null;
088                try {
089                            ic = new InitialContext();
090                    // Standard JNDI binding
091                    return (BeanManager)ic.lookup("java:comp/BeanManager");
092                }
093                catch (NameNotFoundException e) {
094                    if (ic == null)
095                            throw new RuntimeException("No InitialContext");
096                    
097                    // Weld/Tomcat
098                    try {
099                            return (BeanManager)ic.lookup("java:comp/env/BeanManager"); 
100                    }
101                    catch (Exception e1) {          
102                            // JBoss 5/6 (maybe obsolete in Weld 1.0+)
103                            try {
104                                    return (BeanManager)ic.lookup("java:app/BeanManager");
105                            }
106                        catch (Exception e2) {
107                            throw new RuntimeException("Could not find Bean Manager", e2);
108                        }
109                    }
110                }
111                catch (Exception e) {
112                    throw new RuntimeException("Could not find Bean Manager", e);
113                }
114        }
115        
116        
117            private WeldListener listener = new WeldListener();
118            
119            private static final String MESSAGECOUNT_ATTR = CDIInterceptor.class.getName() + "_messageCount";
120            private static final String REQUESTWRAPPER_ATTR = CDIInterceptor.class.getName() + "_requestWrapper";
121        
122        
123            public void before(Message amf3RequestMessage) {
124                    if (log.isTraceEnabled())
125                            log.trace("Pre processing of request message: %s", amf3RequestMessage);
126                    
127                    try {
128                            GraniteContext context = GraniteContext.getCurrentInstance();
129                            
130                            if (context instanceof HttpGraniteContext) {
131                                    HttpGraniteContext httpContext = ((HttpGraniteContext)context);
132                                    Integer wrapCount = (Integer)httpContext.getRequest().getAttribute(MESSAGECOUNT_ATTR);
133                                    if (wrapCount == null) {
134                                            log.debug("Clearing default Weld request context");
135                                    ServletRequestEvent event = new ServletRequestEvent(httpContext.getServletContext(), httpContext.getRequest());
136                                            listener.requestDestroyed(event);
137                                            httpContext.getRequest().setAttribute(MESSAGECOUNT_ATTR, 1);
138                                    }
139                                    else
140                                            httpContext.getRequest().setAttribute(MESSAGECOUNT_ATTR, wrapCount+1);
141                                    
142                            log.debug("Initializing wrapped AMF request");
143                            
144                        HttpServletRequestParamWrapper requestWrapper = new HttpServletRequestParamWrapper(httpContext.getRequest());
145                        httpContext.getRequest().setAttribute(REQUESTWRAPPER_ATTR, requestWrapper);
146                                    
147                            // Now export the headers - copy the headers to request object
148                            Map<String, Object> headerMap = amf3RequestMessage.getHeaders();
149                            if (headerMap != null && headerMap.size() > 0) {
150                                    Iterator<String> headerKeys = headerMap.keySet().iterator();
151                                    while (headerKeys.hasNext()) {
152                                            String key = headerKeys.next();
153                                            String value = headerMap.get(key) == null ? null : headerMap.get(key).toString();
154                                            if (value != null)
155                                                    requestWrapper.setParameter(key, value);
156                                    }
157                            }
158                        
159                            ServletRequestEvent event = new ServletRequestEvent(((HttpGraniteContext)context).getServletContext(), requestWrapper);
160                            listener.requestInitialized(event);
161                        
162                            // Initialize CDI Context
163                                String conversationId = (String)amf3RequestMessage.getHeader(CONVERSATION_ID);
164                                
165                                BeanManager beanManager = lookupBeanManager();
166                                
167                                Conversation conversation = conversationManager.initConversation(beanManager, conversationId);
168                                
169                                @SuppressWarnings("unchecked")
170                                Bean<EventState> eventBean = (Bean<EventState>)beanManager.getBeans(EventState.class).iterator().next();
171                                EventState eventState = (EventState)beanManager.getReference(eventBean, EventState.class, beanManager.createCreationalContext(eventBean));
172                                if (!conversation.isTransient())
173                                    eventState.setWasLongRunning(true);
174                                
175                                if (conversationId != null && conversation.isTransient()) {
176                                        log.debug("Starting conversation " + conversationId);
177                                        conversation.begin(conversationId);
178                                }
179                                    
180                            if (Boolean.TRUE.toString().equals(amf3RequestMessage.getHeader("org.granite.tide.isFirstCall"))) {
181                                    @SuppressWarnings("unchecked")
182                                    Bean<SessionState> ssBean = (Bean<SessionState>)beanManager.getBeans(SessionState.class).iterator().next();
183                                    ((SessionState)beanManager.getReference(ssBean, SessionState.class, beanManager.createCreationalContext(ssBean))).setFirstCall(true);
184                            }
185                                    
186                            if (Boolean.TRUE.toString().equals(amf3RequestMessage.getHeader("org.granite.tide.isFirstConversationCall")) && !conversation.isTransient()) {
187                                    @SuppressWarnings("unchecked")
188                                    Bean<ConversationState> csBean = (Bean<ConversationState>)beanManager.getBeans(ConversationState.class).iterator().next();
189                                    ((ConversationState)beanManager.getReference(csBean, ConversationState.class, beanManager.createCreationalContext(csBean))).setFirstCall(true);
190                            }
191                            }
192                    }
193                    catch(Exception e) {
194                log.error(e, "Exception while pre processing the request message.");
195                throw new ServiceException("Error while pre processing the request message - " + e.getMessage());
196                    }
197            }
198    
199    
200            public void after(Message amf3RequestMessage, Message amf3ResponseMessage) {            
201                    try {
202                            if (log.isTraceEnabled())
203                                    log.trace("Post processing of response message: %s", amf3ResponseMessage);
204    
205                            GraniteContext context = GraniteContext.getCurrentInstance();
206                            
207                            if (context instanceof HttpGraniteContext) {
208                                BeanManager beanManager = lookupBeanManager();
209                                    try {
210                                            // Add conversation management headers to response
211                                            if (amf3ResponseMessage != null) {
212                                                @SuppressWarnings("unchecked")
213                                                Bean<Conversation> conversationBean = (Bean<Conversation>)beanManager.getBeans(Conversation.class).iterator().next();
214                                                Conversation conversation = (Conversation)beanManager.getReference(conversationBean, Conversation.class, beanManager.createCreationalContext(conversationBean));
215                                                
216                                                @SuppressWarnings("unchecked")
217                                                Bean<EventState> eventBean = (Bean<EventState>)beanManager.getBeans(EventState.class).iterator().next();
218                                                EventState eventState = (EventState)beanManager.getReference(eventBean, EventState.class, beanManager.createCreationalContext(eventBean));
219                                                if (eventState.wasLongRunning() && !conversation.isTransient())
220                                                    amf3ResponseMessage.setHeader(WAS_LONG_RUNNING_CONVERSATION_ENDED, true);
221                                                    
222                                        if (eventState.wasCreated() && !conversation.isTransient())
223                                            amf3ResponseMessage.setHeader(WAS_LONG_RUNNING_CONVERSATION_CREATED, true);
224                                        
225                                        amf3ResponseMessage.setHeader(CONVERSATION_ID, conversation.getId());
226                                                    
227                                        amf3ResponseMessage.setHeader(IS_LONG_RUNNING_CONVERSATION, !conversation.isTransient());
228                                            }
229                                    }
230                                    finally {
231                                            conversationManager.destroyConversation(beanManager);
232                                            
233                                            HttpGraniteContext httpContext = ((HttpGraniteContext)context);
234                                        
235                                            // Destroy the CDI context
236                                            HttpServletRequestParamWrapper requestWrapper = (HttpServletRequestParamWrapper)httpContext.getRequest().getAttribute(REQUESTWRAPPER_ATTR);
237                                            httpContext.getRequest().removeAttribute(REQUESTWRAPPER_ATTR);
238                                    ServletRequestEvent event = new ServletRequestEvent(httpContext.getServletContext(), requestWrapper);
239                                    listener.requestDestroyed(event);
240                                        
241                                    log.debug("Destroying wrapped CDI AMF request");
242                                    
243                                            Integer wrapCount = (Integer)httpContext.getRequest().getAttribute(MESSAGECOUNT_ATTR);
244                                            if (wrapCount == 1) {
245                                                    log.debug("Restoring default Weld request context");
246                                            event = new ServletRequestEvent(((HttpGraniteContext)context).getServletContext(), httpContext.getRequest());
247                                                    listener.requestInitialized(event);
248                                                    httpContext.getRequest().removeAttribute(MESSAGECOUNT_ATTR);
249                                            }
250                                            else
251                                                    httpContext.getRequest().setAttribute(MESSAGECOUNT_ATTR, wrapCount-1);
252                                            
253                                    }
254                            }
255                    }
256                    catch (Exception e) {
257                log.error(e, "Exception while post processing the response message.");
258                throw new ServiceException("Error while post processing the response message - " + e.getMessage());
259                    }
260            }
261    }