001/**
002 *   GRANITE DATA SERVICES
003 *   Copyright (C) 2006-2013 GRANITE DATA SERVICES S.A.S.
004 *
005 *   This file is part of the Granite Data Services Platform.
006 *
007 *   Granite Data Services is free software; you can redistribute it and/or
008 *   modify it under the terms of the GNU Lesser General Public
009 *   License as published by the Free Software Foundation; either
010 *   version 2.1 of the License, or (at your option) any later version.
011 *
012 *   Granite Data Services is distributed in the hope that it will be useful,
013 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
014 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser
015 *   General Public License for more details.
016 *
017 *   You should have received a copy of the GNU Lesser General Public
018 *   License along with this library; if not, write to the Free Software
019 *   Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301,
020 *   USA, or see <http://www.gnu.org/licenses/>.
021 */
022package org.granite.gravity.tomcat;
023
024import java.util.List;
025
026import javax.servlet.ServletConfig;
027import javax.servlet.ServletException;
028import javax.servlet.http.HttpServletRequest;
029import javax.servlet.http.HttpSession;
030
031import org.apache.catalina.websocket.StreamInbound;
032import org.apache.catalina.websocket.WebSocketServlet;
033import org.granite.context.GraniteContext;
034import org.granite.gravity.Gravity;
035import org.granite.gravity.GravityManager;
036import org.granite.gravity.GravityServletUtil;
037import org.granite.logging.Logger;
038import org.granite.messaging.webapp.ServletGraniteContext;
039import org.granite.util.ContentType;
040
041import flex.messaging.messages.CommandMessage;
042import flex.messaging.messages.Message;
043
044
045public class TomcatWebSocketServlet extends WebSocketServlet {
046        
047        private static final long serialVersionUID = 1L;
048        
049        private static final Logger log = Logger.getLogger(TomcatWebSocketServlet.class);
050        
051        @Override
052        public void init(ServletConfig config) throws ServletException {
053                super.init(config);
054                
055                GravityServletUtil.init(config);
056        }
057        
058        @Override
059        protected String selectSubProtocol(List<String> subProtocols) {
060        for (String protocol : subProtocols) {
061            if (protocol.startsWith("org.granite.gravity"))
062                return protocol;
063        }
064                return null;
065        }
066        
067        @Override
068        protected StreamInbound createWebSocketInbound(String protocol, HttpServletRequest request) {
069                Gravity gravity = GravityManager.getGravity(getServletContext());
070                TomcatWebSocketChannelFactory channelFactory = new TomcatWebSocketChannelFactory(gravity);
071                
072                try {
073                        String connectMessageId = request.getHeader("connectId") != null ? request.getHeader("connectId") : request.getParameter("connectId");
074                        String clientId = request.getHeader("GDSClientId") != null ? request.getHeader("GDSClientId") : request.getParameter("GDSClientId");
075                        String clientType = request.getHeader("GDSClientType") != null ? request.getHeader("GDSClientType") : request.getParameter("GDSClientType");
076                        String sessionId = null;
077                        HttpSession session = request.getSession(false);
078                        if (session != null) {
079                        ServletGraniteContext.createThreadInstance(gravity.getGraniteConfig(), gravity.getServicesConfig(), 
080                                getServletContext(), session, clientType);
081                        
082                                sessionId = session.getId();
083                        }
084                        else if (request.getCookies() != null) {
085                                for (int i = 0; i < request.getCookies().length; i++) {
086                                        if ("JSESSIONID".equals(request.getCookies()[i].getName())) {
087                                                sessionId = request.getCookies()[i].getValue();
088                                                break;
089                                        }
090                                }                               
091                                
092                                ServletGraniteContext.createThreadInstance(gravity.getGraniteConfig(), gravity.getServicesConfig(), 
093                                getServletContext(), sessionId, clientType); 
094                        }
095            else {
096                ServletGraniteContext.createThreadInstance(gravity.getGraniteConfig(), gravity.getServicesConfig(),
097                        getServletContext(), (String)null, clientType);
098            }
099                        
100                        log.info("WebSocket connection started %s clientId %s sessionId %s", protocol, clientId, sessionId);
101                        
102                        CommandMessage pingMessage = new CommandMessage();
103                        pingMessage.setMessageId(connectMessageId != null ? connectMessageId : "OPEN_CONNECTION");
104                        pingMessage.setOperation(CommandMessage.CLIENT_PING_OPERATION);
105                        if (clientId != null)
106                                pingMessage.setClientId(clientId);
107                        
108                        Message ackMessage = gravity.handleMessage(channelFactory, pingMessage);
109            if (sessionId != null)
110                ackMessage.setHeader("JSESSIONID", sessionId);
111
112                        TomcatWebSocketChannel channel = gravity.getChannel(channelFactory, (String)ackMessage.getClientId());
113            channel.setSession(session);
114
115            String ctype = request.getContentType();
116            if (ctype == null && protocol.length() > "org.granite.gravity".length())
117                ctype = "application/x-" + protocol.substring("org.granite.gravity.".length());
118
119            ContentType contentType = ContentType.forMimeType(ctype);
120                        if (contentType == null) {
121                                log.warn("No (or unsupported) content type in request: %s", request.getContentType());
122                                contentType = ContentType.AMF;
123                        }
124                        channel.setContentType(contentType);
125                        
126                        if (!ackMessage.getClientId().equals(clientId))
127                                channel.setConnectAckMessage(ackMessage);
128                        
129                        return channel.getStreamInbound();
130                }
131                finally {
132                        GraniteContext.release();
133                }
134        }
135//
136//    @Override
137//    protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
138//
139//        // Information required to send the server handshake message
140//        String key;
141//        String subProtocol = null;
142//        List<String> extensions = Collections.emptyList();
143//
144//        if (!headerContainsToken(req, "upgrade", "websocket")) {
145//            resp.sendError(HttpServletResponse.SC_BAD_REQUEST);
146//            return;
147//        }
148//
149//        if (!headerContainsToken(req, "connection", "upgrade")) {
150//            resp.sendError(HttpServletResponse.SC_BAD_REQUEST);
151//            return;
152//        }
153//
154//        if (!headerContainsToken(req, "sec-websocket-version", "13")) {
155//            resp.setStatus(426);
156//            resp.setHeader("Sec-WebSocket-Version", "13");
157//            return;
158//        }
159//
160//        key = req.getHeader("Sec-WebSocket-Key");
161//        if (key == null) {
162//            resp.sendError(HttpServletResponse.SC_BAD_REQUEST);
163//            return;
164//        }
165//
166//        String origin = req.getHeader("Origin");
167//        if (!verifyOrigin(origin)) {
168//            resp.sendError(HttpServletResponse.SC_FORBIDDEN);
169//            return;
170//        }
171//
172//        // Fix for Tomcat-7.0.29 bad header name (was Sec-WebSocket-Protocol-Client")
173//        List<String> subProtocols = getTokensFromHeader(req, "Sec-WebSocket-Protocol");
174//        if (!subProtocols.isEmpty())
175//            subProtocol = selectSubProtocol(subProtocols);
176//
177//        // TODO Read client handshake - Sec-WebSocket-Extensions
178//
179//        // TODO Extensions require the ability to specify something (API TBD)
180//        //      that can be passed to the Tomcat internals and process extension
181//        //      data present when the frame is fragmented.
182//
183//        // If we got this far, all is good. Accept the connection.
184//        resp.setHeader("Upgrade", "websocket");
185//        resp.setHeader("Connection", "upgrade");
186//        resp.setHeader("Sec-WebSocket-Accept", getWebSocketAccept(key));
187//        if (subProtocol != null)
188//            resp.setHeader("Sec-WebSocket-Protocol", subProtocol);
189//
190//        if (!extensions.isEmpty()) {
191//            // TODO
192//        }
193//
194//        WsHttpServletRequestWrapper wrapper = new WsHttpServletRequestWrapper(req);
195//        StreamInbound inbound = createWebSocketInbound(subProtocol, wrapper);
196//        wrapper.invalidate();
197//
198//        // Hack to avoid chunked transfer
199//        resp.setContentLength(((TomcatWebSocketChannel.MessageInboundImpl)inbound).getAckLength());
200//
201//        // Small hack until the Servlet API provides a way to do this.
202//        ServletRequest inner = req;
203//        // Unwrap the request
204//        while (inner instanceof ServletRequestWrapper)
205//            inner = ((ServletRequestWrapper)inner).getRequest();
206//
207//        if (inner instanceof RequestFacade)
208//            ((RequestFacade)inner).doUpgrade(inbound);
209//        else
210//            resp.sendError(HttpServletResponse.SC_INTERNAL_SERVER_ERROR, sm.getString("servlet.reqUpgradeFail"));
211//    }
212//
213//
214//    private boolean headerContainsToken(HttpServletRequest req,
215//            String headerName, String target) {
216//        Enumeration<String> headers = req.getHeaders(headerName);
217//        while (headers.hasMoreElements()) {
218//            String header = headers.nextElement();
219//            String[] tokens = header.split(",");
220//            for (String token : tokens) {
221//                if (target.equalsIgnoreCase(token.trim())) {
222//                    return true;
223//                }
224//            }
225//        }
226//        return false;
227//    }
228//
229//    private List<String> getTokensFromHeader(HttpServletRequest req,
230//            String headerName) {
231//        List<String> result = new ArrayList<String>();
232//
233//        Enumeration<String> headers = req.getHeaders(headerName);
234//        while (headers.hasMoreElements()) {
235//            String header = headers.nextElement();
236//            String[] tokens = header.split(",");
237//            for (String token : tokens) {
238//                result.add(token.trim());
239//            }
240//        }
241//        return result;
242//    }
243//
244//      private String getWebSocketAccept(String key) throws ServletException {
245//
246//        MessageDigest sha1Helper = sha1Helpers.poll();
247//        if (sha1Helper == null) {
248//            try {
249//                sha1Helper = MessageDigest.getInstance("SHA1");
250//            } catch (NoSuchAlgorithmException e) {
251//                throw new ServletException(e);
252//            }
253//        }
254//
255//        sha1Helper.reset();
256//        sha1Helper.update(key.getBytes(B2CConverter.ISO_8859_1));
257//        String result = Base64.encode(sha1Helper.digest(WS_ACCEPT));
258//
259//        sha1Helpers.add(sha1Helper);
260//
261//        return result;
262//    }
263}