001/**
002 * Copyright (C) 2012 FuseSource, Inc.
003 * http://fusesource.com
004 *
005 * Licensed under the Apache License, Version 2.0 (the "License");
006 * you may not use this file except in compliance with the License.
007 * You may obtain a copy of the License at
008 *
009 *    http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.fusesource.hawtdispatch.transport;
019
020import org.fusesource.hawtdispatch.Task;
021
022import javax.net.ssl.*;
023import java.io.EOFException;
024import java.io.IOException;
025import java.net.Socket;
026import java.net.URI;
027import java.nio.ByteBuffer;
028import java.nio.channels.*;
029import java.security.cert.Certificate;
030import java.security.cert.X509Certificate;
031import java.util.ArrayList;
032
033import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NEED_UNWRAP;
034import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NEED_WRAP;
035import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING;
036import static javax.net.ssl.SSLEngineResult.Status.BUFFER_OVERFLOW;
037
038/**
039 * An SSL Transport for secure communications.
040 *
041 * @author <a href="http://hiramchirino.com">Hiram Chirino</a>
042 */
043public class SslTransport extends TcpTransport implements SecuredSession {
044
045
046    /**
047     * Maps uri schemes to a protocol algorithm names.
048     * Valid algorithm names listed at:
049     * http://download.oracle.com/javase/6/docs/technotes/guides/security/StandardNames.html#SSLContext
050     */
051    public static String protocol(String scheme) {
052        if( scheme.equals("tls") ) {
053            return "TLS";
054        } else if( scheme.startsWith("tlsv") ) {
055            return "TLSv"+scheme.substring(4);
056        } else if( scheme.equals("ssl") ) {
057            return "SSL";
058        } else if( scheme.startsWith("sslv") ) {
059            return "SSLv"+scheme.substring(4);
060        }
061        return null;
062    }
063
064    enum ClientAuth {
065        WANT, NEED, NONE
066    };
067
068    private ClientAuth clientAuth = ClientAuth.WANT;
069
070    private SSLContext sslContext;
071    private SSLEngine engine;
072
073    private ByteBuffer readBuffer;
074    private boolean readUnderflow;
075
076    private ByteBuffer writeBuffer;
077    private boolean writeFlushing;
078
079    private ByteBuffer readOverflowBuffer;
080    private SSLChannel ssl_channel = new SSLChannel();
081
082
083    public void setSSLContext(SSLContext ctx) {
084        this.sslContext = ctx;
085    }
086
087    /**
088     * Allows subclasses of TcpTransportFactory to create custom instances of
089     * TcpTransport.
090     */
091    public static SslTransport createTransport(URI uri) throws Exception {
092        String protocol = protocol(uri.getScheme());
093        if( protocol !=null ) {
094            SslTransport rc = new SslTransport();
095            rc.setSSLContext(SSLContext.getInstance(protocol));
096            return rc;
097        }
098        return null;
099    }
100
101    public class SSLChannel implements ScatteringByteChannel, GatheringByteChannel {
102
103        public int write(ByteBuffer plain) throws IOException {
104            return secure_write(plain);
105        }
106
107        public int read(ByteBuffer plain) throws IOException {
108            return secure_read(plain);
109        }
110
111        public boolean isOpen() {
112            return getSocketChannel().isOpen();
113        }
114
115        public void close() throws IOException {
116            getSocketChannel().close();
117        }
118
119        public long write(ByteBuffer[] srcs, int offset, int length) throws IOException {
120            if(offset+length > srcs.length || length<0 || offset<0) {
121                throw new IndexOutOfBoundsException();
122            }
123            long rc=0;
124            for (int i = 0; i < length; i++) {
125                ByteBuffer src = srcs[offset+i];
126                if(src.hasRemaining()) {
127                    rc += write(src);
128                }
129                if( src.hasRemaining() ) {
130                    return rc;
131                }
132            }
133            return rc;
134        }
135
136        public long write(ByteBuffer[] srcs) throws IOException {
137            return write(srcs, 0, srcs.length);
138        }
139
140        public long read(ByteBuffer[] dsts, int offset, int length) throws IOException {
141            if(offset+length > dsts.length || length<0 || offset<0) {
142                throw new IndexOutOfBoundsException();
143            }
144            long rc=0;
145            for (int i = 0; i < length; i++) {
146                ByteBuffer dst = dsts[offset+i];
147                if(dst.hasRemaining()) {
148                    rc += read(dst);
149                }
150                if( dst.hasRemaining() ) {
151                    return rc;
152                }
153            }
154            return rc;
155        }
156
157        public long read(ByteBuffer[] dsts) throws IOException {
158            return read(dsts, 0, dsts.length);
159        }
160        
161        public Socket socket() {
162            SocketChannel c = channel;
163            if( c == null ) {
164                return null;
165            }
166            return c.socket();
167        }
168    }
169
170    public SSLSession getSSLSession() {
171        return engine==null ? null : engine.getSession();
172    }
173
174    public X509Certificate[] getPeerX509Certificates() {
175        if( engine==null ) {
176            return null;
177        }
178        try {
179            ArrayList<X509Certificate> rc = new ArrayList<X509Certificate>();
180            for( Certificate c:engine.getSession().getPeerCertificates() ) {
181                if(c instanceof X509Certificate) {
182                    rc.add((X509Certificate) c);
183                }
184            }
185            return rc.toArray(new X509Certificate[rc.size()]);
186        } catch (SSLPeerUnverifiedException e) {
187            return null;
188        }
189    }
190
191    @Override
192    public void connecting(URI remoteLocation, URI localLocation) throws Exception {
193        assert engine == null;
194        engine = sslContext.createSSLEngine(remoteLocation.getHost(), remoteLocation.getPort());
195        engine.setUseClientMode(true);
196        super.connecting(remoteLocation, localLocation);
197    }
198
199    @Override
200    public void connected(SocketChannel channel) throws Exception {
201        if (engine == null) {
202            engine = sslContext.createSSLEngine();
203            engine.setUseClientMode(false);
204            switch (clientAuth) {
205                case WANT: engine.setWantClientAuth(true); break;
206                case NEED: engine.setNeedClientAuth(true); break;
207                case NONE: engine.setWantClientAuth(false); break;
208            }
209
210        }
211        super.connected(channel);
212    }
213
214    @Override
215    protected void initializeChannel() throws Exception {
216        super.initializeChannel();
217        SSLSession session = engine.getSession();
218        readBuffer = ByteBuffer.allocateDirect(session.getPacketBufferSize());
219        readBuffer.flip();
220        writeBuffer = ByteBuffer.allocateDirect(session.getPacketBufferSize());
221    }
222
223    @Override
224    protected void onConnected() throws IOException {
225        super.onConnected();
226        engine.beginHandshake();
227        handshake();
228    }
229
230    @Override
231    public void flush() {
232        if ( engine.getHandshakeStatus()!=NOT_HANDSHAKING ) {
233            handshake();
234        } else {
235            super.flush();
236        }
237    }
238
239    @Override
240    public void drainInbound() {
241        if ( engine.getHandshakeStatus()!=NOT_HANDSHAKING ) {
242            handshake();
243        } else {
244            super.drainInbound();
245        }
246    }
247
248    /**
249     * @return true if fully flushed.
250     * @throws IOException
251     */
252    protected boolean transportFlush() throws IOException {
253        while (true) {
254            if(writeFlushing) {
255                int count = super.getWriteChannel().write(writeBuffer);
256                if( !writeBuffer.hasRemaining() ) {
257                    writeBuffer.clear();
258                    writeFlushing = false;
259                    suspendWrite();
260                    return true;
261                } else {
262                    return false;
263                }
264            } else {
265                if( writeBuffer.position()!=0 ) {
266                    writeBuffer.flip();
267                    writeFlushing = true;
268                    resumeWrite();
269                } else {
270                    return true;
271                }
272            }
273        }
274    }
275
276    private int secure_write(ByteBuffer plain) throws IOException {
277        if( !transportFlush() ) {
278            // can't write anymore until the write_secured_buffer gets fully flushed out..
279            return 0;
280        }
281        int rc = 0;
282        while ( plain.hasRemaining() ^ engine.getHandshakeStatus()==NEED_WRAP ) {
283            SSLEngineResult result = engine.wrap(plain, writeBuffer);
284            assert result.getStatus()!= BUFFER_OVERFLOW;
285            rc += result.bytesConsumed();
286            if( !transportFlush() ) {
287                break;
288            }
289        }
290        if( plain.remaining()==0 && engine.getHandshakeStatus()!=NOT_HANDSHAKING ) {
291            dispatchQueue.execute(new Task() {
292                public void run() {
293                    handshake();
294                }
295            });
296        }
297        return rc;
298    }
299
300    private int secure_read(ByteBuffer plain) throws IOException {
301        int rc=0;
302        while ( plain.hasRemaining() ^ engine.getHandshakeStatus() == NEED_UNWRAP ) {
303            if( readOverflowBuffer !=null ) {
304                if(  plain.hasRemaining() ) {
305                    // lets drain the overflow buffer before trying to suck down anymore
306                    // network bytes.
307                    int size = Math.min(plain.remaining(), readOverflowBuffer.remaining());
308                    plain.put(readOverflowBuffer.array(), readOverflowBuffer.position(), size);
309                    readOverflowBuffer.position(readOverflowBuffer.position()+size);
310                    if( !readOverflowBuffer.hasRemaining() ) {
311                        readOverflowBuffer = null;
312                    }
313                    rc += size;
314                } else {
315                    return rc;
316                }
317            } else if( readUnderflow ) {
318                int count = super.getReadChannel().read(readBuffer);
319                if( count == -1 ) {  // peer closed socket.
320                    if (rc==0) {
321                        return -1;
322                    } else {
323                        return rc;
324                    }
325                }
326                if( count==0 ) {  // no data available right now.
327                    return rc;
328                }
329                // read in some more data, perhaps now we can unwrap.
330                readUnderflow = false;
331                readBuffer.flip();
332            } else {
333                SSLEngineResult result = engine.unwrap(readBuffer, plain);
334                rc += result.bytesProduced();
335                if( result.getStatus() == BUFFER_OVERFLOW ) {
336                    readOverflowBuffer = ByteBuffer.allocate(engine.getSession().getApplicationBufferSize());
337                    result = engine.unwrap(readBuffer, readOverflowBuffer);
338                    if( readOverflowBuffer.position()==0 ) {
339                        readOverflowBuffer = null;
340                    } else {
341                        readOverflowBuffer.flip();
342                    }
343                }
344                switch( result.getStatus() ) {
345                    case CLOSED:
346                        if (rc==0) {
347                            engine.closeInbound();
348                            return -1;
349                        } else {
350                            return rc;
351                        }
352                    case OK:
353                        if ( engine.getHandshakeStatus()!=NOT_HANDSHAKING ) {
354                            dispatchQueue.execute(new Task() {
355                                public void run() {
356                                    handshake();
357                                }
358                            });
359                        }
360                        break;
361                    case BUFFER_UNDERFLOW:
362                        readBuffer.compact();
363                        readUnderflow = true;
364                        break;
365                    case BUFFER_OVERFLOW:
366                        throw new AssertionError("Unexpected case.");
367                }
368            }
369        }
370        return rc;
371    }
372
373    public void handshake() {
374        try {
375            if( !transportFlush() ) {
376                return;
377            }
378            switch (engine.getHandshakeStatus()) {
379                case NEED_TASK:
380                    final Runnable task = engine.getDelegatedTask();
381                    if( task!=null ) {
382                        blockingExecutor.execute(new Task() {
383                            public void run() {
384                                task.run();
385                                dispatchQueue.execute(new Task() {
386                                    public void run() {
387                                        if (isConnected()) {
388                                            handshake();
389                                        }
390                                    }
391                                });
392                            }
393                        });
394                    }
395                    break;
396
397                case NEED_WRAP:
398                    secure_write(ByteBuffer.allocate(0));
399                    break;
400
401                case NEED_UNWRAP:
402                    if( secure_read(ByteBuffer.allocate(0)) == -1) {
403                        throw new EOFException("Peer disconnected during ssl handshake");
404                    }
405                    break;
406
407                case FINISHED:
408                case NOT_HANDSHAKING:
409                    drainOutboundSource.merge(1);
410                    drainInbound();
411                    break;
412
413                default:
414                    System.err.println("Unexpected ssl engine handshake status: "+ engine.getHandshakeStatus());
415                    break;
416            }
417        } catch (IOException e ) {
418            onTransportFailure(e);
419        }
420    }
421
422
423    public ReadableByteChannel getReadChannel() {
424        return ssl_channel;
425    }
426
427    public WritableByteChannel getWriteChannel() {
428        return ssl_channel;
429    }
430
431    public String getClientAuth() {
432        return clientAuth.name();
433    }
434
435    public void setClientAuth(String clientAuth) {
436        this.clientAuth = ClientAuth.valueOf(clientAuth.toUpperCase());
437    }
438}
439
440