/**
 * Copyright (C) 2010-2011, FuseSource Corp.  All rights reserved.
 *
 *     http://fusesource.com
 *
 * The software in this package is published under the terms of the
 * CDDL license, a copy of which has been included with this distribution
 * in the license.txt file.
 */
package org.fusesource.hawtdispatch.transport;

import javax.net.ssl.*;
import java.io.IOException;
import java.net.Socket;
import java.net.URI;
import java.nio.ByteBuffer;
import java.nio.channels.*;
import java.security.cert.Certificate;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.Locale;
import java.util.concurrent.Executor;

import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NEED_UNWRAP;
import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NEED_WRAP;
import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING;
import static javax.net.ssl.SSLEngineResult.Status.BUFFER_OVERFLOW;
027    
028    /**
029     * An SSL Transport for secure communications.
030     *
031     * @author <a href="http://hiramchirino.com">Hiram Chirino</a>
032     */
033    public class SslTransport extends TcpTransport {
034    
035    
036        /**
037         * Maps uri schemes to a protocol algorithm names.
038         * Valid algorithm names listed at:
039         * http://download.oracle.com/javase/6/docs/technotes/guides/security/StandardNames.html#SSLContext
040         */
041        public static String protocol(String scheme) {
042            if( scheme.equals("tls") ) {
043                return "TLS";
044            } else if( scheme.startsWith("tlsv") ) {
045                return "TLSv"+scheme.substring(4);
046            } else if( scheme.equals("ssl") ) {
047                return "SSL";
048            } else if( scheme.startsWith("sslv") ) {
049                return "SSLv"+scheme.substring(4);
050            }
051            return null;
052        }
053    
054        private SSLContext sslContext;
055        private SSLEngine engine;
056    
057        private ByteBuffer readBuffer;
058        private boolean readUnderflow;
059    
060        private ByteBuffer writeBuffer;
061        private boolean writeFlushing;
062    
063        private ByteBuffer readOverflowBuffer;
064        private SSLChannel ssl_channel = new SSLChannel();
065    
066        private Executor blockingExecutor;
067    
068        public void setSSLContext(SSLContext ctx) {
069            this.sslContext = ctx;
070        }
071    
072        /**
073         * Allows subclasses of TcpTransportFactory to create custom instances of
074         * TcpTransport.
075         */
076        public static SslTransport createTransport(URI uri) throws Exception {
077            String protocol = protocol(uri.getScheme());
078            if( protocol !=null ) {
079                SslTransport rc = new SslTransport();
080                rc.setSSLContext(SSLContext.getInstance(protocol));
081                return rc;
082            }
083            return null;
084        }
085    
086        public class SSLChannel implements ScatteringByteChannel, GatheringByteChannel {
087    
088            public int write(ByteBuffer plain) throws IOException {
089                return secure_write(plain);
090            }
091    
092            public int read(ByteBuffer plain) throws IOException {
093                return secure_read(plain);
094            }
095    
096            public boolean isOpen() {
097                return getSocketChannel().isOpen();
098            }
099    
100            public void close() throws IOException {
101                getSocketChannel().close();
102            }
103    
104            public long write(ByteBuffer[] srcs, int offset, int length) throws IOException {
105                if(offset+length > srcs.length || length<0 || offset<0) {
106                    throw new IndexOutOfBoundsException();
107                }
108                long rc=0;
109                for (int i = 0; i < length; i++) {
110                    ByteBuffer src = srcs[offset+i];
111                    if(src.hasRemaining()) {
112                        rc += write(src);
113                    }
114                    if( src.hasRemaining() ) {
115                        return rc;
116                    }
117                }
118                return rc;
119            }
120    
121            public long write(ByteBuffer[] srcs) throws IOException {
122                return write(srcs, 0, srcs.length);
123            }
124    
125            public long read(ByteBuffer[] dsts, int offset, int length) throws IOException {
126                if(offset+length > dsts.length || length<0 || offset<0) {
127                    throw new IndexOutOfBoundsException();
128                }
129                long rc=0;
130                for (int i = 0; i < length; i++) {
131                    ByteBuffer dst = dsts[offset+i];
132                    if(dst.hasRemaining()) {
133                        rc += read(dst);
134                    }
135                    if( dst.hasRemaining() ) {
136                        return rc;
137                    }
138                }
139                return rc;
140            }
141    
142            public long read(ByteBuffer[] dsts) throws IOException {
143                return read(dsts, 0, dsts.length);
144            }
145            
146            public Socket socket() {
147                SocketChannel c = channel;
148                if( c == null ) {
149                    return null;
150                }
151                return c.socket();
152            }
153        }
154    
155        public SSLSession getSSLSession() {
156            return engine==null ? null : engine.getSession();
157        }
158    
159        public X509Certificate[] getPeerX509Certificates() {
160            if( engine==null ) {
161                return null;
162            }
163            try {
164                ArrayList<X509Certificate> rc = new ArrayList<X509Certificate>();
165                for( Certificate c:engine.getSession().getPeerCertificates() ) {
166                    if(c instanceof X509Certificate) {
167                        rc.add((X509Certificate) c);
168                    }
169                }
170                return rc.toArray(new X509Certificate[rc.size()]);
171            } catch (SSLPeerUnverifiedException e) {
172                return null;
173            }
174        }
175    
176        @Override
177        public void connecting(URI remoteLocation, URI localLocation) throws Exception {
178            assert engine == null;
179            engine = sslContext.createSSLEngine();
180            engine.setUseClientMode(true);
181            super.connecting(remoteLocation, localLocation);
182        }
183    
184        @Override
185        public void connected(SocketChannel channel) throws Exception {
186            if (engine == null) {
187                engine = sslContext.createSSLEngine();
188                engine.setUseClientMode(false);
189                engine.setWantClientAuth(true);
190            }
191            SSLSession session = engine.getSession();
192            readBuffer = ByteBuffer.allocateDirect(session.getPacketBufferSize());
193            readBuffer.flip();
194            writeBuffer = ByteBuffer.allocateDirect(session.getPacketBufferSize());
195    
196            super.connected(channel);
197        }
198    
199        @Override
200        protected void onConnected() throws IOException {
201            super.onConnected();
202            engine.setWantClientAuth(true);
203            engine.beginHandshake();
204            handshake();
205        }
206    
207        @Override
208        public void flush() {
209            if ( engine.getHandshakeStatus()!=NOT_HANDSHAKING ) {
210                handshake();
211            } else {
212                super.flush();
213            }
214        }
215    
216        @Override
217        protected void drainInbound() {
218            if ( engine.getHandshakeStatus()!=NOT_HANDSHAKING ) {
219                handshake();
220            } else {
221                super.drainInbound();
222            }
223        }
224    
225        /**
226         * @return true if fully flushed.
227         * @throws IOException
228         */
229        protected boolean transportFlush() throws IOException {
230            while (true) {
231                if(writeFlushing) {
232                    int count = super.writeChannel().write(writeBuffer);
233                    if( !writeBuffer.hasRemaining() ) {
234                        writeBuffer.clear();
235                        writeFlushing = false;
236                        suspendWrite();
237                        return true;
238                    } else {
239                        return false;
240                    }
241                } else {
242                    if( writeBuffer.position()!=0 ) {
243                        writeBuffer.flip();
244                        writeFlushing = true;
245                        resumeWrite();
246                    } else {
247                        return true;
248                    }
249                }
250            }
251        }
252    
253        private int secure_write(ByteBuffer plain) throws IOException {
254            if( !transportFlush() ) {
255                // can't write anymore until the write_secured_buffer gets fully flushed out..
256                return 0;
257            }
258            int rc = 0;
259            while ( plain.hasRemaining() || engine.getHandshakeStatus()==NEED_WRAP ) {
260                SSLEngineResult result = engine.wrap(plain, writeBuffer);
261                assert result.getStatus()!= BUFFER_OVERFLOW;
262                rc += result.bytesConsumed();
263                if( !transportFlush() ) {
264                    break;
265                }
266            }
267            if( plain.remaining()==0 && engine.getHandshakeStatus()!=NOT_HANDSHAKING ) {
268                dispatchQueue.execute(new Runnable() {
269                    public void run() {
270                        handshake();
271                    }
272                });
273            }
274            return rc;
275        }
276    
277        private int secure_read(ByteBuffer plain) throws IOException {
278            int rc=0;
279            while ( plain.hasRemaining() || engine.getHandshakeStatus() == NEED_UNWRAP ) {
280                if( readOverflowBuffer !=null ) {
281                    if(  plain.hasRemaining() ) {
282                        // lets drain the overflow buffer before trying to suck down anymore
283                        // network bytes.
284                        int size = Math.min(plain.remaining(), readOverflowBuffer.remaining());
285                        plain.put(readOverflowBuffer.array(), readOverflowBuffer.position(), size);
286                        readOverflowBuffer.position(readOverflowBuffer.position()+size);
287                        if( !readOverflowBuffer.hasRemaining() ) {
288                            readOverflowBuffer = null;
289                        }
290                        rc += size;
291                    } else {
292                        return rc;
293                    }
294                } else if( readUnderflow ) {
295                    int count = super.readChannel().read(readBuffer);
296                    if( count == -1 ) {  // peer closed socket.
297                        if (rc==0) {
298                            return -1;
299                        } else {
300                            return rc;
301                        }
302                    }
303                    if( count==0 ) {  // no data available right now.
304                        return rc;
305                    }
306                    // read in some more data, perhaps now we can unwrap.
307                    readUnderflow = false;
308                    readBuffer.flip();
309                } else {
310                    SSLEngineResult result = engine.unwrap(readBuffer, plain);
311                    rc += result.bytesProduced();
312                    if( result.getStatus() == BUFFER_OVERFLOW ) {
313                        readOverflowBuffer = ByteBuffer.allocate(engine.getSession().getApplicationBufferSize());
314                        result = engine.unwrap(readBuffer, readOverflowBuffer);
315                        if( readOverflowBuffer.position()==0 ) {
316                            readOverflowBuffer = null;
317                        } else {
318                            readOverflowBuffer.flip();
319                        }
320                    }
321                    switch( result.getStatus() ) {
322                        case CLOSED:
323                            if (rc==0) {
324                                engine.closeInbound();
325                                return -1;
326                            } else {
327                                return rc;
328                            }
329                        case OK:
330                            if ( engine.getHandshakeStatus()!=NOT_HANDSHAKING ) {
331                                dispatchQueue.execute(new Runnable() {
332                                    public void run() {
333                                        handshake();
334                                    }
335                                });
336                            }
337                            break;
338                        case BUFFER_UNDERFLOW:
339                            readBuffer.compact();
340                            readUnderflow = true;
341                            break;
342                        case BUFFER_OVERFLOW:
343                            throw new AssertionError("Unexpected case.");
344                    }
345                }
346            }
347            return rc;
348        }
349    
350        public void handshake() {
351            try {
352                if( !transportFlush() ) {
353                    return;
354                }
355                switch (engine.getHandshakeStatus()) {
356                    case NEED_TASK:
357                        final Runnable task = engine.getDelegatedTask();
358                        if( task!=null ) {
359                            blockingExecutor.execute(new Runnable() {
360                                public void run() {
361                                    task.run();
362                                    dispatchQueue.execute(new Runnable() {
363                                        public void run() {
364                                            if (isConnected()) {
365                                                handshake();
366                                            }
367                                        }
368                                    });
369                                }
370                            });
371                        }
372                        break;
373    
374                    case NEED_WRAP:
375                        secure_write(ByteBuffer.allocate(0));
376                        break;
377    
378                    case NEED_UNWRAP:
379                        secure_read(ByteBuffer.allocate(0));
380                        break;
381    
382                    case FINISHED:
383                    case NOT_HANDSHAKING:
384                        break;
385    
386                    default:
387                        System.err.println("Unexpected ssl engine handshake status: "+ engine.getHandshakeStatus());
388                        break;
389                }
390            } catch (IOException e ) {
391                onTransportFailure(e);
392            }
393        }
394    
395    
396        public ReadableByteChannel readChannel() {
397            return ssl_channel;
398        }
399    
400        public WritableByteChannel writeChannel() {
401            return ssl_channel;
402        }
403    
404        public Executor getBlockingExecutor() {
405            return blockingExecutor;
406        }
407    
408        public void setBlockingExecutor(Executor blockingExecutor) {
409            this.blockingExecutor = blockingExecutor;
410        }
411    }
412    
413