001    /**
002     * Copyright (C) 2010-2011, FuseSource Corp.  All rights reserved.
003     *
004     *     http://fusesource.com
005     *
006     * The software in this package is published under the terms of the
007     * CDDL license a copy of which has been included with this distribution
008     * in the license.txt file.
009     */
010    package org.fusesource.hawtdispatch.transport;
011    
012    import javax.net.ssl.*;
013    import java.io.EOFException;
014    import java.io.IOException;
015    import java.net.Socket;
016    import java.net.URI;
017    import java.nio.ByteBuffer;
018    import java.nio.channels.*;
019    import java.security.cert.Certificate;
020    import java.security.cert.X509Certificate;
021    import java.util.ArrayList;
022    import java.util.concurrent.Executor;
023    
024    import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NEED_UNWRAP;
025    import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NEED_WRAP;
026    import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING;
027    import static javax.net.ssl.SSLEngineResult.Status.BUFFER_OVERFLOW;
028    
029    /**
030     * An SSL Transport for secure communications.
031     *
032     * @author <a href="http://hiramchirino.com">Hiram Chirino</a>
033     */
034    public class SslTransport extends TcpTransport {
035    
036    
037        /**
038         * Maps uri schemes to a protocol algorithm names.
039         * Valid algorithm names listed at:
040         * http://download.oracle.com/javase/6/docs/technotes/guides/security/StandardNames.html#SSLContext
041         */
042        public static String protocol(String scheme) {
043            if( scheme.equals("tls") ) {
044                return "TLS";
045            } else if( scheme.startsWith("tlsv") ) {
046                return "TLSv"+scheme.substring(4);
047            } else if( scheme.equals("ssl") ) {
048                return "SSL";
049            } else if( scheme.startsWith("sslv") ) {
050                return "SSLv"+scheme.substring(4);
051            }
052            return null;
053        }
054    
055        private SSLContext sslContext;
056        private SSLEngine engine;
057    
058        private ByteBuffer readBuffer;
059        private boolean readUnderflow;
060    
061        private ByteBuffer writeBuffer;
062        private boolean writeFlushing;
063    
064        private ByteBuffer readOverflowBuffer;
065        private SSLChannel ssl_channel = new SSLChannel();
066    
067        private Executor blockingExecutor;
068    
069        public void setSSLContext(SSLContext ctx) {
070            this.sslContext = ctx;
071        }
072    
073        /**
074         * Allows subclasses of TcpTransportFactory to create custom instances of
075         * TcpTransport.
076         */
077        public static SslTransport createTransport(URI uri) throws Exception {
078            String protocol = protocol(uri.getScheme());
079            if( protocol !=null ) {
080                SslTransport rc = new SslTransport();
081                rc.setSSLContext(SSLContext.getInstance(protocol));
082                return rc;
083            }
084            return null;
085        }
086    
087        public class SSLChannel implements ScatteringByteChannel, GatheringByteChannel {
088    
089            public int write(ByteBuffer plain) throws IOException {
090                return secure_write(plain);
091            }
092    
093            public int read(ByteBuffer plain) throws IOException {
094                return secure_read(plain);
095            }
096    
097            public boolean isOpen() {
098                return getSocketChannel().isOpen();
099            }
100    
101            public void close() throws IOException {
102                getSocketChannel().close();
103            }
104    
105            public long write(ByteBuffer[] srcs, int offset, int length) throws IOException {
106                if(offset+length > srcs.length || length<0 || offset<0) {
107                    throw new IndexOutOfBoundsException();
108                }
109                long rc=0;
110                for (int i = 0; i < length; i++) {
111                    ByteBuffer src = srcs[offset+i];
112                    if(src.hasRemaining()) {
113                        rc += write(src);
114                    }
115                    if( src.hasRemaining() ) {
116                        return rc;
117                    }
118                }
119                return rc;
120            }
121    
122            public long write(ByteBuffer[] srcs) throws IOException {
123                return write(srcs, 0, srcs.length);
124            }
125    
126            public long read(ByteBuffer[] dsts, int offset, int length) throws IOException {
127                if(offset+length > dsts.length || length<0 || offset<0) {
128                    throw new IndexOutOfBoundsException();
129                }
130                long rc=0;
131                for (int i = 0; i < length; i++) {
132                    ByteBuffer dst = dsts[offset+i];
133                    if(dst.hasRemaining()) {
134                        rc += read(dst);
135                    }
136                    if( dst.hasRemaining() ) {
137                        return rc;
138                    }
139                }
140                return rc;
141            }
142    
143            public long read(ByteBuffer[] dsts) throws IOException {
144                return read(dsts, 0, dsts.length);
145            }
146            
147            public Socket socket() {
148                SocketChannel c = channel;
149                if( c == null ) {
150                    return null;
151                }
152                return c.socket();
153            }
154        }
155    
156        public SSLSession getSSLSession() {
157            return engine==null ? null : engine.getSession();
158        }
159    
160        public X509Certificate[] getPeerX509Certificates() {
161            if( engine==null ) {
162                return null;
163            }
164            try {
165                ArrayList<X509Certificate> rc = new ArrayList<X509Certificate>();
166                for( Certificate c:engine.getSession().getPeerCertificates() ) {
167                    if(c instanceof X509Certificate) {
168                        rc.add((X509Certificate) c);
169                    }
170                }
171                return rc.toArray(new X509Certificate[rc.size()]);
172            } catch (SSLPeerUnverifiedException e) {
173                return null;
174            }
175        }
176    
177        @Override
178        public void connecting(URI remoteLocation, URI localLocation) throws Exception {
179            assert engine == null;
180            engine = sslContext.createSSLEngine();
181            engine.setUseClientMode(true);
182            super.connecting(remoteLocation, localLocation);
183        }
184    
185        @Override
186        public void connected(SocketChannel channel) throws Exception {
187            if (engine == null) {
188                engine = sslContext.createSSLEngine();
189                engine.setUseClientMode(false);
190                engine.setWantClientAuth(true);
191            }
192            super.connected(channel);
193        }
194    
195        @Override
196        protected void initializeChannel() throws Exception {
197            super.initializeChannel();
198            SSLSession session = engine.getSession();
199            readBuffer = ByteBuffer.allocateDirect(session.getPacketBufferSize());
200            readBuffer.flip();
201            writeBuffer = ByteBuffer.allocateDirect(session.getPacketBufferSize());
202        }
203    
204        @Override
205        protected void onConnected() throws IOException {
206            super.onConnected();
207            engine.beginHandshake();
208            handshake();
209        }
210    
211        @Override
212        public void flush() {
213            if ( engine.getHandshakeStatus()!=NOT_HANDSHAKING ) {
214                handshake();
215            } else {
216                super.flush();
217            }
218        }
219    
220        @Override
221        protected void drainInbound() {
222            if ( engine.getHandshakeStatus()!=NOT_HANDSHAKING ) {
223                handshake();
224            } else {
225                super.drainInbound();
226            }
227        }
228    
229        /**
230         * @return true if fully flushed.
231         * @throws IOException
232         */
233        protected boolean transportFlush() throws IOException {
234            while (true) {
235                if(writeFlushing) {
236                    int count = super.writeChannel().write(writeBuffer);
237                    if( !writeBuffer.hasRemaining() ) {
238                        writeBuffer.clear();
239                        writeFlushing = false;
240                        suspendWrite();
241                        return true;
242                    } else {
243                        return false;
244                    }
245                } else {
246                    if( writeBuffer.position()!=0 ) {
247                        writeBuffer.flip();
248                        writeFlushing = true;
249                        resumeWrite();
250                    } else {
251                        return true;
252                    }
253                }
254            }
255        }
256    
257        private int secure_write(ByteBuffer plain) throws IOException {
258            if( !transportFlush() ) {
259                // can't write anymore until the write_secured_buffer gets fully flushed out..
260                return 0;
261            }
262            int rc = 0;
263            while ( plain.hasRemaining() ^ engine.getHandshakeStatus()==NEED_WRAP ) {
264                SSLEngineResult result = engine.wrap(plain, writeBuffer);
265                assert result.getStatus()!= BUFFER_OVERFLOW;
266                rc += result.bytesConsumed();
267                if( !transportFlush() ) {
268                    break;
269                }
270            }
271            if( plain.remaining()==0 && engine.getHandshakeStatus()!=NOT_HANDSHAKING ) {
272                dispatchQueue.execute(new Runnable() {
273                    public void run() {
274                        handshake();
275                    }
276                });
277            }
278            return rc;
279        }
280    
281        private int secure_read(ByteBuffer plain) throws IOException {
282            int rc=0;
283            while ( plain.hasRemaining() ^ engine.getHandshakeStatus() == NEED_UNWRAP ) {
284                if( readOverflowBuffer !=null ) {
285                    if(  plain.hasRemaining() ) {
286                        // lets drain the overflow buffer before trying to suck down anymore
287                        // network bytes.
288                        int size = Math.min(plain.remaining(), readOverflowBuffer.remaining());
289                        plain.put(readOverflowBuffer.array(), readOverflowBuffer.position(), size);
290                        readOverflowBuffer.position(readOverflowBuffer.position()+size);
291                        if( !readOverflowBuffer.hasRemaining() ) {
292                            readOverflowBuffer = null;
293                        }
294                        rc += size;
295                    } else {
296                        return rc;
297                    }
298                } else if( readUnderflow ) {
299                    int count = super.readChannel().read(readBuffer);
300                    if( count == -1 ) {  // peer closed socket.
301                        if (rc==0) {
302                            return -1;
303                        } else {
304                            return rc;
305                        }
306                    }
307                    if( count==0 ) {  // no data available right now.
308                        return rc;
309                    }
310                    // read in some more data, perhaps now we can unwrap.
311                    readUnderflow = false;
312                    readBuffer.flip();
313                } else {
314                    SSLEngineResult result = engine.unwrap(readBuffer, plain);
315                    rc += result.bytesProduced();
316                    if( result.getStatus() == BUFFER_OVERFLOW ) {
317                        readOverflowBuffer = ByteBuffer.allocate(engine.getSession().getApplicationBufferSize());
318                        result = engine.unwrap(readBuffer, readOverflowBuffer);
319                        if( readOverflowBuffer.position()==0 ) {
320                            readOverflowBuffer = null;
321                        } else {
322                            readOverflowBuffer.flip();
323                        }
324                    }
325                    switch( result.getStatus() ) {
326                        case CLOSED:
327                            if (rc==0) {
328                                engine.closeInbound();
329                                return -1;
330                            } else {
331                                return rc;
332                            }
333                        case OK:
334                            if ( engine.getHandshakeStatus()!=NOT_HANDSHAKING ) {
335                                dispatchQueue.execute(new Runnable() {
336                                    public void run() {
337                                        handshake();
338                                    }
339                                });
340                            }
341                            break;
342                        case BUFFER_UNDERFLOW:
343                            readBuffer.compact();
344                            readUnderflow = true;
345                            break;
346                        case BUFFER_OVERFLOW:
347                            throw new AssertionError("Unexpected case.");
348                    }
349                }
350            }
351            return rc;
352        }
353    
354        public void handshake() {
355            try {
356                if( !transportFlush() ) {
357                    return;
358                }
359                switch (engine.getHandshakeStatus()) {
360                    case NEED_TASK:
361                        final Runnable task = engine.getDelegatedTask();
362                        if( task!=null ) {
363                            blockingExecutor.execute(new Runnable() {
364                                public void run() {
365                                    task.run();
366                                    dispatchQueue.execute(new Runnable() {
367                                        public void run() {
368                                            if (isConnected()) {
369                                                handshake();
370                                            }
371                                        }
372                                    });
373                                }
374                            });
375                        }
376                        break;
377    
378                    case NEED_WRAP:
379                        secure_write(ByteBuffer.allocate(0));
380                        break;
381    
382                    case NEED_UNWRAP:
383                        if( secure_read(ByteBuffer.allocate(0)) == -1) {
384                            throw new EOFException("Peer disconnected during ssl handshake");
385                        }
386                        break;
387    
388                    case FINISHED:
389                    case NOT_HANDSHAKING:
390                        drainOutboundSource.merge(1);
391                        break;
392    
393                    default:
394                        System.err.println("Unexpected ssl engine handshake status: "+ engine.getHandshakeStatus());
395                        break;
396                }
397            } catch (IOException e ) {
398                onTransportFailure(e);
399            }
400        }
401    
402    
403        public ReadableByteChannel readChannel() {
404            return ssl_channel;
405        }
406    
407        public WritableByteChannel writeChannel() {
408            return ssl_channel;
409        }
410    
411        public Executor getBlockingExecutor() {
412            return blockingExecutor;
413        }
414    
415        public void setBlockingExecutor(Executor blockingExecutor) {
416            this.blockingExecutor = blockingExecutor;
417        }
418    }
419    
420