001 /**
002 * Copyright (C) 2010-2011, FuseSource Corp. All rights reserved.
003 *
004 * http://fusesource.com
005 *
006 * The software in this package is published under the terms of the
007 * CDDL license a copy of which has been included with this distribution
008 * in the license.txt file.
009 */
010 package org.fusesource.hawtdispatch.transport;
011
012 import javax.net.ssl.*;
013 import java.io.EOFException;
014 import java.io.IOException;
015 import java.net.Socket;
016 import java.net.URI;
017 import java.nio.ByteBuffer;
018 import java.nio.channels.*;
019 import java.security.cert.Certificate;
020 import java.security.cert.X509Certificate;
021 import java.util.ArrayList;
022 import java.util.concurrent.Executor;
023
024 import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NEED_UNWRAP;
025 import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NEED_WRAP;
026 import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING;
027 import static javax.net.ssl.SSLEngineResult.Status.BUFFER_OVERFLOW;
028
029 /**
030 * An SSL Transport for secure communications.
031 *
032 * @author <a href="http://hiramchirino.com">Hiram Chirino</a>
033 */
034 public class SslTransport extends TcpTransport {
035
036
037 /**
038 * Maps uri schemes to a protocol algorithm names.
039 * Valid algorithm names listed at:
040 * http://download.oracle.com/javase/6/docs/technotes/guides/security/StandardNames.html#SSLContext
041 */
042 public static String protocol(String scheme) {
043 if( scheme.equals("tls") ) {
044 return "TLS";
045 } else if( scheme.startsWith("tlsv") ) {
046 return "TLSv"+scheme.substring(4);
047 } else if( scheme.equals("ssl") ) {
048 return "SSL";
049 } else if( scheme.startsWith("sslv") ) {
050 return "SSLv"+scheme.substring(4);
051 }
052 return null;
053 }
054
055 private SSLContext sslContext;
056 private SSLEngine engine;
057
058 private ByteBuffer readBuffer;
059 private boolean readUnderflow;
060
061 private ByteBuffer writeBuffer;
062 private boolean writeFlushing;
063
064 private ByteBuffer readOverflowBuffer;
065 private SSLChannel ssl_channel = new SSLChannel();
066
067 private Executor blockingExecutor;
068
069 public void setSSLContext(SSLContext ctx) {
070 this.sslContext = ctx;
071 }
072
073 /**
074 * Allows subclasses of TcpTransportFactory to create custom instances of
075 * TcpTransport.
076 */
077 public static SslTransport createTransport(URI uri) throws Exception {
078 String protocol = protocol(uri.getScheme());
079 if( protocol !=null ) {
080 SslTransport rc = new SslTransport();
081 rc.setSSLContext(SSLContext.getInstance(protocol));
082 return rc;
083 }
084 return null;
085 }
086
087 public class SSLChannel implements ScatteringByteChannel, GatheringByteChannel {
088
089 public int write(ByteBuffer plain) throws IOException {
090 return secure_write(plain);
091 }
092
093 public int read(ByteBuffer plain) throws IOException {
094 return secure_read(plain);
095 }
096
097 public boolean isOpen() {
098 return getSocketChannel().isOpen();
099 }
100
101 public void close() throws IOException {
102 getSocketChannel().close();
103 }
104
105 public long write(ByteBuffer[] srcs, int offset, int length) throws IOException {
106 if(offset+length > srcs.length || length<0 || offset<0) {
107 throw new IndexOutOfBoundsException();
108 }
109 long rc=0;
110 for (int i = 0; i < length; i++) {
111 ByteBuffer src = srcs[offset+i];
112 if(src.hasRemaining()) {
113 rc += write(src);
114 }
115 if( src.hasRemaining() ) {
116 return rc;
117 }
118 }
119 return rc;
120 }
121
122 public long write(ByteBuffer[] srcs) throws IOException {
123 return write(srcs, 0, srcs.length);
124 }
125
126 public long read(ByteBuffer[] dsts, int offset, int length) throws IOException {
127 if(offset+length > dsts.length || length<0 || offset<0) {
128 throw new IndexOutOfBoundsException();
129 }
130 long rc=0;
131 for (int i = 0; i < length; i++) {
132 ByteBuffer dst = dsts[offset+i];
133 if(dst.hasRemaining()) {
134 rc += read(dst);
135 }
136 if( dst.hasRemaining() ) {
137 return rc;
138 }
139 }
140 return rc;
141 }
142
143 public long read(ByteBuffer[] dsts) throws IOException {
144 return read(dsts, 0, dsts.length);
145 }
146
147 public Socket socket() {
148 SocketChannel c = channel;
149 if( c == null ) {
150 return null;
151 }
152 return c.socket();
153 }
154 }
155
156 public SSLSession getSSLSession() {
157 return engine==null ? null : engine.getSession();
158 }
159
160 public X509Certificate[] getPeerX509Certificates() {
161 if( engine==null ) {
162 return null;
163 }
164 try {
165 ArrayList<X509Certificate> rc = new ArrayList<X509Certificate>();
166 for( Certificate c:engine.getSession().getPeerCertificates() ) {
167 if(c instanceof X509Certificate) {
168 rc.add((X509Certificate) c);
169 }
170 }
171 return rc.toArray(new X509Certificate[rc.size()]);
172 } catch (SSLPeerUnverifiedException e) {
173 return null;
174 }
175 }
176
177 @Override
178 public void connecting(URI remoteLocation, URI localLocation) throws Exception {
179 assert engine == null;
180 engine = sslContext.createSSLEngine();
181 engine.setUseClientMode(true);
182 super.connecting(remoteLocation, localLocation);
183 }
184
185 @Override
186 public void connected(SocketChannel channel) throws Exception {
187 if (engine == null) {
188 engine = sslContext.createSSLEngine();
189 engine.setUseClientMode(false);
190 engine.setWantClientAuth(true);
191 }
192 super.connected(channel);
193 }
194
195 @Override
196 protected void initializeChannel() throws Exception {
197 super.initializeChannel();
198 SSLSession session = engine.getSession();
199 readBuffer = ByteBuffer.allocateDirect(session.getPacketBufferSize());
200 readBuffer.flip();
201 writeBuffer = ByteBuffer.allocateDirect(session.getPacketBufferSize());
202 }
203
204 @Override
205 protected void onConnected() throws IOException {
206 super.onConnected();
207 engine.beginHandshake();
208 handshake();
209 }
210
211 @Override
212 public void flush() {
213 if ( engine.getHandshakeStatus()!=NOT_HANDSHAKING ) {
214 handshake();
215 } else {
216 super.flush();
217 }
218 }
219
220 @Override
221 protected void drainInbound() {
222 if ( engine.getHandshakeStatus()!=NOT_HANDSHAKING ) {
223 handshake();
224 } else {
225 super.drainInbound();
226 }
227 }
228
229 /**
230 * @return true if fully flushed.
231 * @throws IOException
232 */
233 protected boolean transportFlush() throws IOException {
234 while (true) {
235 if(writeFlushing) {
236 int count = super.writeChannel().write(writeBuffer);
237 if( !writeBuffer.hasRemaining() ) {
238 writeBuffer.clear();
239 writeFlushing = false;
240 suspendWrite();
241 return true;
242 } else {
243 return false;
244 }
245 } else {
246 if( writeBuffer.position()!=0 ) {
247 writeBuffer.flip();
248 writeFlushing = true;
249 resumeWrite();
250 } else {
251 return true;
252 }
253 }
254 }
255 }
256
257 private int secure_write(ByteBuffer plain) throws IOException {
258 if( !transportFlush() ) {
259 // can't write anymore until the write_secured_buffer gets fully flushed out..
260 return 0;
261 }
262 int rc = 0;
263 while ( plain.hasRemaining() ^ engine.getHandshakeStatus()==NEED_WRAP ) {
264 SSLEngineResult result = engine.wrap(plain, writeBuffer);
265 assert result.getStatus()!= BUFFER_OVERFLOW;
266 rc += result.bytesConsumed();
267 if( !transportFlush() ) {
268 break;
269 }
270 }
271 if( plain.remaining()==0 && engine.getHandshakeStatus()!=NOT_HANDSHAKING ) {
272 dispatchQueue.execute(new Runnable() {
273 public void run() {
274 handshake();
275 }
276 });
277 }
278 return rc;
279 }
280
281 private int secure_read(ByteBuffer plain) throws IOException {
282 int rc=0;
283 while ( plain.hasRemaining() ^ engine.getHandshakeStatus() == NEED_UNWRAP ) {
284 if( readOverflowBuffer !=null ) {
285 if( plain.hasRemaining() ) {
286 // lets drain the overflow buffer before trying to suck down anymore
287 // network bytes.
288 int size = Math.min(plain.remaining(), readOverflowBuffer.remaining());
289 plain.put(readOverflowBuffer.array(), readOverflowBuffer.position(), size);
290 readOverflowBuffer.position(readOverflowBuffer.position()+size);
291 if( !readOverflowBuffer.hasRemaining() ) {
292 readOverflowBuffer = null;
293 }
294 rc += size;
295 } else {
296 return rc;
297 }
298 } else if( readUnderflow ) {
299 int count = super.readChannel().read(readBuffer);
300 if( count == -1 ) { // peer closed socket.
301 if (rc==0) {
302 return -1;
303 } else {
304 return rc;
305 }
306 }
307 if( count==0 ) { // no data available right now.
308 return rc;
309 }
310 // read in some more data, perhaps now we can unwrap.
311 readUnderflow = false;
312 readBuffer.flip();
313 } else {
314 SSLEngineResult result = engine.unwrap(readBuffer, plain);
315 rc += result.bytesProduced();
316 if( result.getStatus() == BUFFER_OVERFLOW ) {
317 readOverflowBuffer = ByteBuffer.allocate(engine.getSession().getApplicationBufferSize());
318 result = engine.unwrap(readBuffer, readOverflowBuffer);
319 if( readOverflowBuffer.position()==0 ) {
320 readOverflowBuffer = null;
321 } else {
322 readOverflowBuffer.flip();
323 }
324 }
325 switch( result.getStatus() ) {
326 case CLOSED:
327 if (rc==0) {
328 engine.closeInbound();
329 return -1;
330 } else {
331 return rc;
332 }
333 case OK:
334 if ( engine.getHandshakeStatus()!=NOT_HANDSHAKING ) {
335 dispatchQueue.execute(new Runnable() {
336 public void run() {
337 handshake();
338 }
339 });
340 }
341 break;
342 case BUFFER_UNDERFLOW:
343 readBuffer.compact();
344 readUnderflow = true;
345 break;
346 case BUFFER_OVERFLOW:
347 throw new AssertionError("Unexpected case.");
348 }
349 }
350 }
351 return rc;
352 }
353
354 public void handshake() {
355 try {
356 if( !transportFlush() ) {
357 return;
358 }
359 switch (engine.getHandshakeStatus()) {
360 case NEED_TASK:
361 final Runnable task = engine.getDelegatedTask();
362 if( task!=null ) {
363 blockingExecutor.execute(new Runnable() {
364 public void run() {
365 task.run();
366 dispatchQueue.execute(new Runnable() {
367 public void run() {
368 if (isConnected()) {
369 handshake();
370 }
371 }
372 });
373 }
374 });
375 }
376 break;
377
378 case NEED_WRAP:
379 secure_write(ByteBuffer.allocate(0));
380 break;
381
382 case NEED_UNWRAP:
383 if( secure_read(ByteBuffer.allocate(0)) == -1) {
384 throw new EOFException("Peer disconnected during ssl handshake");
385 }
386 break;
387
388 case FINISHED:
389 case NOT_HANDSHAKING:
390 drainOutboundSource.merge(1);
391 break;
392
393 default:
394 System.err.println("Unexpected ssl engine handshake status: "+ engine.getHandshakeStatus());
395 break;
396 }
397 } catch (IOException e ) {
398 onTransportFailure(e);
399 }
400 }
401
402
403 public ReadableByteChannel readChannel() {
404 return ssl_channel;
405 }
406
407 public WritableByteChannel writeChannel() {
408 return ssl_channel;
409 }
410
411 public Executor getBlockingExecutor() {
412 return blockingExecutor;
413 }
414
415 public void setBlockingExecutor(Executor blockingExecutor) {
416 this.blockingExecutor = blockingExecutor;
417 }
418 }
419
420