001 /**
002 * Copyright (C) 2010-2011, FuseSource Corp. All rights reserved.
003 *
004 * http://fusesource.com
005 *
006 * The software in this package is published under the terms of the
007 * CDDL license a copy of which has been included with this distribution
008 * in the license.txt file.
009 */
010 package org.fusesource.hawtdispatch.transport;
011
012 import javax.net.ssl.*;
013 import java.io.IOException;
014 import java.net.Socket;
015 import java.net.URI;
016 import java.nio.ByteBuffer;
017 import java.nio.channels.*;
018 import java.security.cert.Certificate;
019 import java.security.cert.X509Certificate;
020 import java.util.ArrayList;
021 import java.util.concurrent.Executor;
022
023 import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NEED_UNWRAP;
024 import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NEED_WRAP;
025 import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING;
026 import static javax.net.ssl.SSLEngineResult.Status.BUFFER_OVERFLOW;
027
028 /**
029 * An SSL Transport for secure communications.
030 *
031 * @author <a href="http://hiramchirino.com">Hiram Chirino</a>
032 */
033 public class SslTransport extends TcpTransport {
034
035
036 /**
037 * Maps uri schemes to a protocol algorithm names.
038 * Valid algorithm names listed at:
039 * http://download.oracle.com/javase/6/docs/technotes/guides/security/StandardNames.html#SSLContext
040 */
041 public static String protocol(String scheme) {
042 if( scheme.equals("tls") ) {
043 return "TLS";
044 } else if( scheme.startsWith("tlsv") ) {
045 return "TLSv"+scheme.substring(4);
046 } else if( scheme.equals("ssl") ) {
047 return "SSL";
048 } else if( scheme.startsWith("sslv") ) {
049 return "SSLv"+scheme.substring(4);
050 }
051 return null;
052 }
053
054 private SSLContext sslContext;
055 private SSLEngine engine;
056
057 private ByteBuffer readBuffer;
058 private boolean readUnderflow;
059
060 private ByteBuffer writeBuffer;
061 private boolean writeFlushing;
062
063 private ByteBuffer readOverflowBuffer;
064 private SSLChannel ssl_channel = new SSLChannel();
065
066 private Executor blockingExecutor;
067
068 public void setSSLContext(SSLContext ctx) {
069 this.sslContext = ctx;
070 }
071
072 /**
073 * Allows subclasses of TcpTransportFactory to create custom instances of
074 * TcpTransport.
075 */
076 public static SslTransport createTransport(URI uri) throws Exception {
077 String protocol = protocol(uri.getScheme());
078 if( protocol !=null ) {
079 SslTransport rc = new SslTransport();
080 rc.setSSLContext(SSLContext.getInstance(protocol));
081 return rc;
082 }
083 return null;
084 }
085
086 public class SSLChannel implements ScatteringByteChannel, GatheringByteChannel {
087
088 public int write(ByteBuffer plain) throws IOException {
089 return secure_write(plain);
090 }
091
092 public int read(ByteBuffer plain) throws IOException {
093 return secure_read(plain);
094 }
095
096 public boolean isOpen() {
097 return getSocketChannel().isOpen();
098 }
099
100 public void close() throws IOException {
101 getSocketChannel().close();
102 }
103
104 public long write(ByteBuffer[] srcs, int offset, int length) throws IOException {
105 if(offset+length > srcs.length || length<0 || offset<0) {
106 throw new IndexOutOfBoundsException();
107 }
108 long rc=0;
109 for (int i = 0; i < length; i++) {
110 ByteBuffer src = srcs[offset+i];
111 if(src.hasRemaining()) {
112 rc += write(src);
113 }
114 if( src.hasRemaining() ) {
115 return rc;
116 }
117 }
118 return rc;
119 }
120
121 public long write(ByteBuffer[] srcs) throws IOException {
122 return write(srcs, 0, srcs.length);
123 }
124
125 public long read(ByteBuffer[] dsts, int offset, int length) throws IOException {
126 if(offset+length > dsts.length || length<0 || offset<0) {
127 throw new IndexOutOfBoundsException();
128 }
129 long rc=0;
130 for (int i = 0; i < length; i++) {
131 ByteBuffer dst = dsts[offset+i];
132 if(dst.hasRemaining()) {
133 rc += read(dst);
134 }
135 if( dst.hasRemaining() ) {
136 return rc;
137 }
138 }
139 return rc;
140 }
141
142 public long read(ByteBuffer[] dsts) throws IOException {
143 return read(dsts, 0, dsts.length);
144 }
145
146 public Socket socket() {
147 SocketChannel c = channel;
148 if( c == null ) {
149 return null;
150 }
151 return c.socket();
152 }
153 }
154
155 public SSLSession getSSLSession() {
156 return engine==null ? null : engine.getSession();
157 }
158
159 public X509Certificate[] getPeerX509Certificates() {
160 if( engine==null ) {
161 return null;
162 }
163 try {
164 ArrayList<X509Certificate> rc = new ArrayList<X509Certificate>();
165 for( Certificate c:engine.getSession().getPeerCertificates() ) {
166 if(c instanceof X509Certificate) {
167 rc.add((X509Certificate) c);
168 }
169 }
170 return rc.toArray(new X509Certificate[rc.size()]);
171 } catch (SSLPeerUnverifiedException e) {
172 return null;
173 }
174 }
175
176 @Override
177 public void connecting(URI remoteLocation, URI localLocation) throws Exception {
178 assert engine == null;
179 engine = sslContext.createSSLEngine();
180 engine.setUseClientMode(true);
181 super.connecting(remoteLocation, localLocation);
182 }
183
184 @Override
185 public void connected(SocketChannel channel) throws Exception {
186 if (engine == null) {
187 engine = sslContext.createSSLEngine();
188 engine.setUseClientMode(false);
189 engine.setWantClientAuth(true);
190 }
191 SSLSession session = engine.getSession();
192 readBuffer = ByteBuffer.allocateDirect(session.getPacketBufferSize());
193 readBuffer.flip();
194 writeBuffer = ByteBuffer.allocateDirect(session.getPacketBufferSize());
195
196 super.connected(channel);
197 }
198
199 @Override
200 protected void onConnected() throws IOException {
201 super.onConnected();
202 engine.setWantClientAuth(true);
203 engine.beginHandshake();
204 handshake();
205 }
206
207 @Override
208 public void flush() {
209 if ( engine.getHandshakeStatus()!=NOT_HANDSHAKING ) {
210 handshake();
211 } else {
212 super.flush();
213 }
214 }
215
216 @Override
217 protected void drainInbound() {
218 if ( engine.getHandshakeStatus()!=NOT_HANDSHAKING ) {
219 handshake();
220 } else {
221 super.drainInbound();
222 }
223 }
224
225 /**
226 * @return true if fully flushed.
227 * @throws IOException
228 */
229 protected boolean transportFlush() throws IOException {
230 while (true) {
231 if(writeFlushing) {
232 int count = super.writeChannel().write(writeBuffer);
233 if( !writeBuffer.hasRemaining() ) {
234 writeBuffer.clear();
235 writeFlushing = false;
236 suspendWrite();
237 return true;
238 } else {
239 return false;
240 }
241 } else {
242 if( writeBuffer.position()!=0 ) {
243 writeBuffer.flip();
244 writeFlushing = true;
245 resumeWrite();
246 } else {
247 return true;
248 }
249 }
250 }
251 }
252
253 private int secure_write(ByteBuffer plain) throws IOException {
254 if( !transportFlush() ) {
255 // can't write anymore until the write_secured_buffer gets fully flushed out..
256 return 0;
257 }
258 int rc = 0;
259 while ( plain.hasRemaining() || engine.getHandshakeStatus()==NEED_WRAP ) {
260 SSLEngineResult result = engine.wrap(plain, writeBuffer);
261 assert result.getStatus()!= BUFFER_OVERFLOW;
262 rc += result.bytesConsumed();
263 if( !transportFlush() ) {
264 break;
265 }
266 }
267 if( plain.remaining()==0 && engine.getHandshakeStatus()!=NOT_HANDSHAKING ) {
268 dispatchQueue.execute(new Runnable() {
269 public void run() {
270 handshake();
271 }
272 });
273 }
274 return rc;
275 }
276
277 private int secure_read(ByteBuffer plain) throws IOException {
278 int rc=0;
279 while ( plain.hasRemaining() || engine.getHandshakeStatus() == NEED_UNWRAP ) {
280 if( readOverflowBuffer !=null ) {
281 if( plain.hasRemaining() ) {
282 // lets drain the overflow buffer before trying to suck down anymore
283 // network bytes.
284 int size = Math.min(plain.remaining(), readOverflowBuffer.remaining());
285 plain.put(readOverflowBuffer.array(), readOverflowBuffer.position(), size);
286 readOverflowBuffer.position(readOverflowBuffer.position()+size);
287 if( !readOverflowBuffer.hasRemaining() ) {
288 readOverflowBuffer = null;
289 }
290 rc += size;
291 } else {
292 return rc;
293 }
294 } else if( readUnderflow ) {
295 int count = super.readChannel().read(readBuffer);
296 if( count == -1 ) { // peer closed socket.
297 if (rc==0) {
298 return -1;
299 } else {
300 return rc;
301 }
302 }
303 if( count==0 ) { // no data available right now.
304 return rc;
305 }
306 // read in some more data, perhaps now we can unwrap.
307 readUnderflow = false;
308 readBuffer.flip();
309 } else {
310 SSLEngineResult result = engine.unwrap(readBuffer, plain);
311 rc += result.bytesProduced();
312 if( result.getStatus() == BUFFER_OVERFLOW ) {
313 readOverflowBuffer = ByteBuffer.allocate(engine.getSession().getApplicationBufferSize());
314 result = engine.unwrap(readBuffer, readOverflowBuffer);
315 if( readOverflowBuffer.position()==0 ) {
316 readOverflowBuffer = null;
317 } else {
318 readOverflowBuffer.flip();
319 }
320 }
321 switch( result.getStatus() ) {
322 case CLOSED:
323 if (rc==0) {
324 engine.closeInbound();
325 return -1;
326 } else {
327 return rc;
328 }
329 case OK:
330 if ( engine.getHandshakeStatus()!=NOT_HANDSHAKING ) {
331 dispatchQueue.execute(new Runnable() {
332 public void run() {
333 handshake();
334 }
335 });
336 }
337 break;
338 case BUFFER_UNDERFLOW:
339 readBuffer.compact();
340 readUnderflow = true;
341 break;
342 case BUFFER_OVERFLOW:
343 throw new AssertionError("Unexpected case.");
344 }
345 }
346 }
347 return rc;
348 }
349
350 public void handshake() {
351 try {
352 if( !transportFlush() ) {
353 return;
354 }
355 switch (engine.getHandshakeStatus()) {
356 case NEED_TASK:
357 final Runnable task = engine.getDelegatedTask();
358 if( task!=null ) {
359 blockingExecutor.execute(new Runnable() {
360 public void run() {
361 task.run();
362 dispatchQueue.execute(new Runnable() {
363 public void run() {
364 if (isConnected()) {
365 handshake();
366 }
367 }
368 });
369 }
370 });
371 }
372 break;
373
374 case NEED_WRAP:
375 secure_write(ByteBuffer.allocate(0));
376 break;
377
378 case NEED_UNWRAP:
379 secure_read(ByteBuffer.allocate(0));
380 break;
381
382 case FINISHED:
383 case NOT_HANDSHAKING:
384 break;
385
386 default:
387 System.err.println("Unexpected ssl engine handshake status: "+ engine.getHandshakeStatus());
388 break;
389 }
390 } catch (IOException e ) {
391 onTransportFailure(e);
392 }
393 }
394
395
396 public ReadableByteChannel readChannel() {
397 return ssl_channel;
398 }
399
400 public WritableByteChannel writeChannel() {
401 return ssl_channel;
402 }
403
404 public Executor getBlockingExecutor() {
405 return blockingExecutor;
406 }
407
408 public void setBlockingExecutor(Executor blockingExecutor) {
409 this.blockingExecutor = blockingExecutor;
410 }
411 }
412
413