001    /**
002     * Copyright (C) 2012 FuseSource, Inc.
003     * http://fusesource.com
004     *
005     * Licensed under the Apache License, Version 2.0 (the "License");
006     * you may not use this file except in compliance with the License.
007     * You may obtain a copy of the License at
008     *
009     *    http://www.apache.org/licenses/LICENSE-2.0
010     *
011     * Unless required by applicable law or agreed to in writing, software
012     * distributed under the License is distributed on an "AS IS" BASIS,
013     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014     * See the License for the specific language governing permissions and
015     * limitations under the License.
016     */
017    
018    package org.fusesource.hawtdispatch.transport;
019    
020    import org.fusesource.hawtdispatch.Task;
021    
022    import javax.net.ssl.*;
023    import java.io.EOFException;
024    import java.io.IOException;
025    import java.net.Socket;
026    import java.net.URI;
027    import java.nio.ByteBuffer;
028    import java.nio.channels.*;
029    import java.security.cert.Certificate;
030    import java.security.cert.X509Certificate;
031    import java.util.ArrayList;
032    import java.util.concurrent.Executor;
033    
034    import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NEED_UNWRAP;
035    import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NEED_WRAP;
036    import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING;
037    import static javax.net.ssl.SSLEngineResult.Status.BUFFER_OVERFLOW;
038    
039    /**
040     * An SSL Transport for secure communications.
041     *
042     * @author <a href="http://hiramchirino.com">Hiram Chirino</a>
043     */
044    public class SslTransport extends TcpTransport implements SecureTransport {
045    
046    
047        /**
048         * Maps uri schemes to a protocol algorithm names.
049         * Valid algorithm names listed at:
050         * http://download.oracle.com/javase/6/docs/technotes/guides/security/StandardNames.html#SSLContext
051         */
052        public static String protocol(String scheme) {
053            if( scheme.equals("tls") ) {
054                return "TLS";
055            } else if( scheme.startsWith("tlsv") ) {
056                return "TLSv"+scheme.substring(4);
057            } else if( scheme.equals("ssl") ) {
058                return "SSL";
059            } else if( scheme.startsWith("sslv") ) {
060                return "SSLv"+scheme.substring(4);
061            }
062            return null;
063        }
064    
065        enum ClientAuth {
066            WANT, NEED, NONE
067        };
068    
069        private ClientAuth clientAuth = ClientAuth.WANT;
070    
071        private SSLContext sslContext;
072        private SSLEngine engine;
073    
074        private ByteBuffer readBuffer;
075        private boolean readUnderflow;
076    
077        private ByteBuffer writeBuffer;
078        private boolean writeFlushing;
079    
080        private ByteBuffer readOverflowBuffer;
081        private SSLChannel ssl_channel = new SSLChannel();
082    
083        private Executor blockingExecutor;
084    
085        public void setSSLContext(SSLContext ctx) {
086            this.sslContext = ctx;
087        }
088    
089        /**
090         * Allows subclasses of TcpTransportFactory to create custom instances of
091         * TcpTransport.
092         */
093        public static SslTransport createTransport(URI uri) throws Exception {
094            String protocol = protocol(uri.getScheme());
095            if( protocol !=null ) {
096                SslTransport rc = new SslTransport();
097                rc.setSSLContext(SSLContext.getInstance(protocol));
098                return rc;
099            }
100            return null;
101        }
102    
103        public class SSLChannel implements ScatteringByteChannel, GatheringByteChannel {
104    
105            public int write(ByteBuffer plain) throws IOException {
106                return secure_write(plain);
107            }
108    
109            public int read(ByteBuffer plain) throws IOException {
110                return secure_read(plain);
111            }
112    
113            public boolean isOpen() {
114                return getSocketChannel().isOpen();
115            }
116    
117            public void close() throws IOException {
118                getSocketChannel().close();
119            }
120    
121            public long write(ByteBuffer[] srcs, int offset, int length) throws IOException {
122                if(offset+length > srcs.length || length<0 || offset<0) {
123                    throw new IndexOutOfBoundsException();
124                }
125                long rc=0;
126                for (int i = 0; i < length; i++) {
127                    ByteBuffer src = srcs[offset+i];
128                    if(src.hasRemaining()) {
129                        rc += write(src);
130                    }
131                    if( src.hasRemaining() ) {
132                        return rc;
133                    }
134                }
135                return rc;
136            }
137    
138            public long write(ByteBuffer[] srcs) throws IOException {
139                return write(srcs, 0, srcs.length);
140            }
141    
142            public long read(ByteBuffer[] dsts, int offset, int length) throws IOException {
143                if(offset+length > dsts.length || length<0 || offset<0) {
144                    throw new IndexOutOfBoundsException();
145                }
146                long rc=0;
147                for (int i = 0; i < length; i++) {
148                    ByteBuffer dst = dsts[offset+i];
149                    if(dst.hasRemaining()) {
150                        rc += read(dst);
151                    }
152                    if( dst.hasRemaining() ) {
153                        return rc;
154                    }
155                }
156                return rc;
157            }
158    
159            public long read(ByteBuffer[] dsts) throws IOException {
160                return read(dsts, 0, dsts.length);
161            }
162            
163            public Socket socket() {
164                SocketChannel c = channel;
165                if( c == null ) {
166                    return null;
167                }
168                return c.socket();
169            }
170        }
171    
172        public SSLSession getSSLSession() {
173            return engine==null ? null : engine.getSession();
174        }
175    
176        public X509Certificate[] getPeerX509Certificates() {
177            if( engine==null ) {
178                return null;
179            }
180            try {
181                ArrayList<X509Certificate> rc = new ArrayList<X509Certificate>();
182                for( Certificate c:engine.getSession().getPeerCertificates() ) {
183                    if(c instanceof X509Certificate) {
184                        rc.add((X509Certificate) c);
185                    }
186                }
187                return rc.toArray(new X509Certificate[rc.size()]);
188            } catch (SSLPeerUnverifiedException e) {
189                return null;
190            }
191        }
192    
193        @Override
194        public void connecting(URI remoteLocation, URI localLocation) throws Exception {
195            assert engine == null;
196            engine = sslContext.createSSLEngine();
197            engine.setUseClientMode(true);
198            super.connecting(remoteLocation, localLocation);
199        }
200    
201        @Override
202        public void connected(SocketChannel channel) throws Exception {
203            if (engine == null) {
204                engine = sslContext.createSSLEngine();
205                engine.setUseClientMode(false);
206                switch (clientAuth) {
207                    case WANT: engine.setWantClientAuth(true); break;
208                    case NEED: engine.setNeedClientAuth(true); break;
209                    case NONE: engine.setWantClientAuth(false); break;
210                }
211    
212            }
213            super.connected(channel);
214        }
215    
216        @Override
217        protected void initializeChannel() throws Exception {
218            super.initializeChannel();
219            SSLSession session = engine.getSession();
220            readBuffer = ByteBuffer.allocateDirect(session.getPacketBufferSize());
221            readBuffer.flip();
222            writeBuffer = ByteBuffer.allocateDirect(session.getPacketBufferSize());
223        }
224    
225        @Override
226        protected void onConnected() throws IOException {
227            super.onConnected();
228            engine.beginHandshake();
229            handshake();
230        }
231    
232        @Override
233        public void flush() {
234            if ( engine.getHandshakeStatus()!=NOT_HANDSHAKING ) {
235                handshake();
236            } else {
237                super.flush();
238            }
239        }
240    
241        @Override
242        protected void drainInbound() {
243            if ( engine.getHandshakeStatus()!=NOT_HANDSHAKING ) {
244                handshake();
245            } else {
246                super.drainInbound();
247            }
248        }
249    
250        /**
251         * @return true if fully flushed.
252         * @throws IOException
253         */
254        protected boolean transportFlush() throws IOException {
255            while (true) {
256                if(writeFlushing) {
257                    int count = super.writeChannel().write(writeBuffer);
258                    if( !writeBuffer.hasRemaining() ) {
259                        writeBuffer.clear();
260                        writeFlushing = false;
261                        suspendWrite();
262                        return true;
263                    } else {
264                        return false;
265                    }
266                } else {
267                    if( writeBuffer.position()!=0 ) {
268                        writeBuffer.flip();
269                        writeFlushing = true;
270                        resumeWrite();
271                    } else {
272                        return true;
273                    }
274                }
275            }
276        }
277    
278        private int secure_write(ByteBuffer plain) throws IOException {
279            if( !transportFlush() ) {
280                // can't write anymore until the write_secured_buffer gets fully flushed out..
281                return 0;
282            }
283            int rc = 0;
284            while ( plain.hasRemaining() ^ engine.getHandshakeStatus()==NEED_WRAP ) {
285                SSLEngineResult result = engine.wrap(plain, writeBuffer);
286                assert result.getStatus()!= BUFFER_OVERFLOW;
287                rc += result.bytesConsumed();
288                if( !transportFlush() ) {
289                    break;
290                }
291            }
292            if( plain.remaining()==0 && engine.getHandshakeStatus()!=NOT_HANDSHAKING ) {
293                dispatchQueue.execute(new Task() {
294                    public void run() {
295                        handshake();
296                    }
297                });
298            }
299            return rc;
300        }
301    
302        private int secure_read(ByteBuffer plain) throws IOException {
303            int rc=0;
304            while ( plain.hasRemaining() ^ engine.getHandshakeStatus() == NEED_UNWRAP ) {
305                if( readOverflowBuffer !=null ) {
306                    if(  plain.hasRemaining() ) {
307                        // lets drain the overflow buffer before trying to suck down anymore
308                        // network bytes.
309                        int size = Math.min(plain.remaining(), readOverflowBuffer.remaining());
310                        plain.put(readOverflowBuffer.array(), readOverflowBuffer.position(), size);
311                        readOverflowBuffer.position(readOverflowBuffer.position()+size);
312                        if( !readOverflowBuffer.hasRemaining() ) {
313                            readOverflowBuffer = null;
314                        }
315                        rc += size;
316                    } else {
317                        return rc;
318                    }
319                } else if( readUnderflow ) {
320                    int count = super.readChannel().read(readBuffer);
321                    if( count == -1 ) {  // peer closed socket.
322                        if (rc==0) {
323                            return -1;
324                        } else {
325                            return rc;
326                        }
327                    }
328                    if( count==0 ) {  // no data available right now.
329                        return rc;
330                    }
331                    // read in some more data, perhaps now we can unwrap.
332                    readUnderflow = false;
333                    readBuffer.flip();
334                } else {
335                    SSLEngineResult result = engine.unwrap(readBuffer, plain);
336                    rc += result.bytesProduced();
337                    if( result.getStatus() == BUFFER_OVERFLOW ) {
338                        readOverflowBuffer = ByteBuffer.allocate(engine.getSession().getApplicationBufferSize());
339                        result = engine.unwrap(readBuffer, readOverflowBuffer);
340                        if( readOverflowBuffer.position()==0 ) {
341                            readOverflowBuffer = null;
342                        } else {
343                            readOverflowBuffer.flip();
344                        }
345                    }
346                    switch( result.getStatus() ) {
347                        case CLOSED:
348                            if (rc==0) {
349                                engine.closeInbound();
350                                return -1;
351                            } else {
352                                return rc;
353                            }
354                        case OK:
355                            if ( engine.getHandshakeStatus()!=NOT_HANDSHAKING ) {
356                                dispatchQueue.execute(new Task() {
357                                    public void run() {
358                                        handshake();
359                                    }
360                                });
361                            }
362                            break;
363                        case BUFFER_UNDERFLOW:
364                            readBuffer.compact();
365                            readUnderflow = true;
366                            break;
367                        case BUFFER_OVERFLOW:
368                            throw new AssertionError("Unexpected case.");
369                    }
370                }
371            }
372            return rc;
373        }
374    
375        public void handshake() {
376            try {
377                if( !transportFlush() ) {
378                    return;
379                }
380                switch (engine.getHandshakeStatus()) {
381                    case NEED_TASK:
382                        final Runnable task = engine.getDelegatedTask();
383                        if( task!=null ) {
384                            blockingExecutor.execute(new Task() {
385                                public void run() {
386                                    task.run();
387                                    dispatchQueue.execute(new Task() {
388                                        public void run() {
389                                            if (isConnected()) {
390                                                handshake();
391                                            }
392                                        }
393                                    });
394                                }
395                            });
396                        }
397                        break;
398    
399                    case NEED_WRAP:
400                        secure_write(ByteBuffer.allocate(0));
401                        break;
402    
403                    case NEED_UNWRAP:
404                        if( secure_read(ByteBuffer.allocate(0)) == -1) {
405                            throw new EOFException("Peer disconnected during ssl handshake");
406                        }
407                        break;
408    
409                    case FINISHED:
410                    case NOT_HANDSHAKING:
411                        drainOutboundSource.merge(1);
412                        break;
413    
414                    default:
415                        System.err.println("Unexpected ssl engine handshake status: "+ engine.getHandshakeStatus());
416                        break;
417                }
418            } catch (IOException e ) {
419                onTransportFailure(e);
420            }
421        }
422    
423    
424        public ReadableByteChannel readChannel() {
425            return ssl_channel;
426        }
427    
428        public WritableByteChannel writeChannel() {
429            return ssl_channel;
430        }
431    
432        public Executor getBlockingExecutor() {
433            return blockingExecutor;
434        }
435    
436        public void setBlockingExecutor(Executor blockingExecutor) {
437            this.blockingExecutor = blockingExecutor;
438        }
439    
440        public String getClientAuth() {
441            return clientAuth.name();
442        }
443    
444        public void setClientAuth(String clientAuth) {
445            this.clientAuth = ClientAuth.valueOf(clientAuth.toUpperCase());
446        }
447    }
448    
449