001 /**
002 * GRANITE DATA SERVICES
003 * Copyright (C) 2006-2013 GRANITE DATA SERVICES S.A.S.
004 *
005 * This file is part of the Granite Data Services Platform.
006 *
007 * Granite Data Services is free software; you can redistribute it and/or
008 * modify it under the terms of the GNU Lesser General Public
009 * License as published by the Free Software Foundation; either
010 * version 2.1 of the License, or (at your option) any later version.
011 *
012 * Granite Data Services is distributed in the hope that it will be useful,
013 * but WITHOUT ANY WARRANTY; without even the implied warranty of
014 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser
015 * General Public License for more details.
016 *
017 * You should have received a copy of the GNU Lesser General Public
018 * License along with this library; if not, write to the Free Software
019 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301,
020 * USA, or see <http://www.gnu.org/licenses/>.
021 */
022 package org.granite.gravity.tomcat;
023
024 import java.util.List;
025
026 import javax.servlet.ServletConfig;
027 import javax.servlet.ServletException;
028 import javax.servlet.http.HttpServletRequest;
029 import javax.servlet.http.HttpSession;
030
031 import org.apache.catalina.websocket.StreamInbound;
032 import org.apache.catalina.websocket.WebSocketServlet;
033 import org.granite.context.GraniteContext;
034 import org.granite.gravity.Gravity;
035 import org.granite.gravity.GravityManager;
036 import org.granite.gravity.GravityServletUtil;
037 import org.granite.logging.Logger;
038 import org.granite.messaging.webapp.ServletGraniteContext;
039 import org.granite.util.ContentType;
040
041 import flex.messaging.messages.CommandMessage;
042 import flex.messaging.messages.Message;
043
044
045 public class TomcatWebSocketServlet extends WebSocketServlet {
046
047 private static final long serialVersionUID = 1L;
048
049 private static final Logger log = Logger.getLogger(TomcatWebSocketServlet.class);
050
051 @Override
052 public void init(ServletConfig config) throws ServletException {
053 super.init(config);
054
055 GravityServletUtil.init(config);
056 }
057
058 @Override
059 protected String selectSubProtocol(List<String> subProtocols) {
060 for (String protocol : subProtocols) {
061 if (protocol.startsWith("org.granite.gravity"))
062 return protocol;
063 }
064 return null;
065 }
066
067 @Override
068 protected StreamInbound createWebSocketInbound(String protocol, HttpServletRequest request) {
069 Gravity gravity = GravityManager.getGravity(getServletContext());
070 TomcatWebSocketChannelFactory channelFactory = new TomcatWebSocketChannelFactory(gravity);
071
072 try {
073 String connectMessageId = request.getHeader("connectId") != null ? request.getHeader("connectId") : request.getParameter("connectId");
074 String clientId = request.getHeader("GDSClientId") != null ? request.getHeader("GDSClientId") : request.getParameter("GDSClientId");
075 String clientType = request.getHeader("GDSClientType") != null ? request.getHeader("GDSClientType") : request.getParameter("GDSClientType");
076 String sessionId = null;
077 HttpSession session = request.getSession(false);
078 if (session != null) {
079 ServletGraniteContext.createThreadInstance(gravity.getGraniteConfig(), gravity.getServicesConfig(),
080 getServletContext(), session, clientType);
081
082 sessionId = session.getId();
083 }
084 else if (request.getCookies() != null) {
085 for (int i = 0; i < request.getCookies().length; i++) {
086 if ("JSESSIONID".equals(request.getCookies()[i].getName())) {
087 sessionId = request.getCookies()[i].getValue();
088 break;
089 }
090 }
091
092 ServletGraniteContext.createThreadInstance(gravity.getGraniteConfig(), gravity.getServicesConfig(),
093 getServletContext(), sessionId, clientType);
094 }
095 else {
096 ServletGraniteContext.createThreadInstance(gravity.getGraniteConfig(), gravity.getServicesConfig(),
097 getServletContext(), (String)null, clientType);
098 }
099
100 log.info("WebSocket connection started %s clientId %s sessionId %s", protocol, clientId, sessionId);
101
102 CommandMessage pingMessage = new CommandMessage();
103 pingMessage.setMessageId(connectMessageId != null ? connectMessageId : "OPEN_CONNECTION");
104 pingMessage.setOperation(CommandMessage.CLIENT_PING_OPERATION);
105 if (clientId != null)
106 pingMessage.setClientId(clientId);
107
108 Message ackMessage = gravity.handleMessage(channelFactory, pingMessage);
109 if (sessionId != null)
110 ackMessage.setHeader("JSESSIONID", sessionId);
111
112 TomcatWebSocketChannel channel = gravity.getChannel(channelFactory, (String)ackMessage.getClientId());
113 channel.setSession(session);
114
115 String ctype = request.getContentType();
116 if (ctype == null && protocol.length() > "org.granite.gravity".length())
117 ctype = "application/x-" + protocol.substring("org.granite.gravity.".length());
118
119 ContentType contentType = ContentType.forMimeType(ctype);
120 if (contentType == null) {
121 log.warn("No (or unsupported) content type in request: %s", request.getContentType());
122 contentType = ContentType.AMF;
123 }
124 channel.setContentType(contentType);
125
126 if (!ackMessage.getClientId().equals(clientId))
127 channel.setConnectAckMessage(ackMessage);
128
129 return channel.getStreamInbound();
130 }
131 finally {
132 GraniteContext.release();
133 }
134 }
135 //
136 // @Override
137 // protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
138 //
139 // // Information required to send the server handshake message
140 // String key;
141 // String subProtocol = null;
142 // List<String> extensions = Collections.emptyList();
143 //
144 // if (!headerContainsToken(req, "upgrade", "websocket")) {
145 // resp.sendError(HttpServletResponse.SC_BAD_REQUEST);
146 // return;
147 // }
148 //
149 // if (!headerContainsToken(req, "connection", "upgrade")) {
150 // resp.sendError(HttpServletResponse.SC_BAD_REQUEST);
151 // return;
152 // }
153 //
154 // if (!headerContainsToken(req, "sec-websocket-version", "13")) {
155 // resp.setStatus(426);
156 // resp.setHeader("Sec-WebSocket-Version", "13");
157 // return;
158 // }
159 //
160 // key = req.getHeader("Sec-WebSocket-Key");
161 // if (key == null) {
162 // resp.sendError(HttpServletResponse.SC_BAD_REQUEST);
163 // return;
164 // }
165 //
166 // String origin = req.getHeader("Origin");
167 // if (!verifyOrigin(origin)) {
168 // resp.sendError(HttpServletResponse.SC_FORBIDDEN);
169 // return;
170 // }
171 //
172 // // Fix for Tomcat-7.0.29 bad header name (was Sec-WebSocket-Protocol-Client")
173 // List<String> subProtocols = getTokensFromHeader(req, "Sec-WebSocket-Protocol");
174 // if (!subProtocols.isEmpty())
175 // subProtocol = selectSubProtocol(subProtocols);
176 //
177 // // TODO Read client handshake - Sec-WebSocket-Extensions
178 //
179 // // TODO Extensions require the ability to specify something (API TBD)
180 // // that can be passed to the Tomcat internals and process extension
181 // // data present when the frame is fragmented.
182 //
183 // // If we got this far, all is good. Accept the connection.
184 // resp.setHeader("Upgrade", "websocket");
185 // resp.setHeader("Connection", "upgrade");
186 // resp.setHeader("Sec-WebSocket-Accept", getWebSocketAccept(key));
187 // if (subProtocol != null)
188 // resp.setHeader("Sec-WebSocket-Protocol", subProtocol);
189 //
190 // if (!extensions.isEmpty()) {
191 // // TODO
192 // }
193 //
194 // WsHttpServletRequestWrapper wrapper = new WsHttpServletRequestWrapper(req);
195 // StreamInbound inbound = createWebSocketInbound(subProtocol, wrapper);
196 // wrapper.invalidate();
197 //
198 // // Hack to avoid chunked transfer
199 // resp.setContentLength(((TomcatWebSocketChannel.MessageInboundImpl)inbound).getAckLength());
200 //
201 // // Small hack until the Servlet API provides a way to do this.
202 // ServletRequest inner = req;
203 // // Unwrap the request
204 // while (inner instanceof ServletRequestWrapper)
205 // inner = ((ServletRequestWrapper)inner).getRequest();
206 //
207 // if (inner instanceof RequestFacade)
208 // ((RequestFacade)inner).doUpgrade(inbound);
209 // else
210 // resp.sendError(HttpServletResponse.SC_INTERNAL_SERVER_ERROR, sm.getString("servlet.reqUpgradeFail"));
211 // }
212 //
213 //
214 // private boolean headerContainsToken(HttpServletRequest req,
215 // String headerName, String target) {
216 // Enumeration<String> headers = req.getHeaders(headerName);
217 // while (headers.hasMoreElements()) {
218 // String header = headers.nextElement();
219 // String[] tokens = header.split(",");
220 // for (String token : tokens) {
221 // if (target.equalsIgnoreCase(token.trim())) {
222 // return true;
223 // }
224 // }
225 // }
226 // return false;
227 // }
228 //
229 // private List<String> getTokensFromHeader(HttpServletRequest req,
230 // String headerName) {
231 // List<String> result = new ArrayList<String>();
232 //
233 // Enumeration<String> headers = req.getHeaders(headerName);
234 // while (headers.hasMoreElements()) {
235 // String header = headers.nextElement();
236 // String[] tokens = header.split(",");
237 // for (String token : tokens) {
238 // result.add(token.trim());
239 // }
240 // }
241 // return result;
242 // }
243 //
244 // private String getWebSocketAccept(String key) throws ServletException {
245 //
246 // MessageDigest sha1Helper = sha1Helpers.poll();
247 // if (sha1Helper == null) {
248 // try {
249 // sha1Helper = MessageDigest.getInstance("SHA1");
250 // } catch (NoSuchAlgorithmException e) {
251 // throw new ServletException(e);
252 // }
253 // }
254 //
255 // sha1Helper.reset();
256 // sha1Helper.update(key.getBytes(B2CConverter.ISO_8859_1));
257 // String result = Base64.encode(sha1Helper.digest(WS_ACCEPT));
258 //
259 // sha1Helpers.add(sha1Helper);
260 //
261 // return result;
262 // }
263 }