001 /**
002 * Copyright (C) 2012 FuseSource, Inc.
003 * http://fusesource.com
004 *
005 * Licensed under the Apache License, Version 2.0 (the "License");
006 * you may not use this file except in compliance with the License.
007 * You may obtain a copy of the License at
008 *
009 * http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018 package org.fusesource.hawtdispatch.transport;
019
020 import org.fusesource.hawtdispatch.Task;
021
022 import javax.net.ssl.*;
023 import java.io.EOFException;
024 import java.io.IOException;
025 import java.net.Socket;
026 import java.net.URI;
027 import java.nio.ByteBuffer;
028 import java.nio.channels.*;
029 import java.security.cert.Certificate;
030 import java.security.cert.X509Certificate;
031 import java.util.ArrayList;
032 import java.util.concurrent.Executor;
033
034 import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NEED_UNWRAP;
035 import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NEED_WRAP;
036 import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING;
037 import static javax.net.ssl.SSLEngineResult.Status.BUFFER_OVERFLOW;
038
039 /**
040 * An SSL Transport for secure communications.
041 *
042 * @author <a href="http://hiramchirino.com">Hiram Chirino</a>
043 */
044 public class SslTransport extends TcpTransport implements SecureTransport {
045
046
047 /**
048 * Maps uri schemes to a protocol algorithm names.
049 * Valid algorithm names listed at:
050 * http://download.oracle.com/javase/6/docs/technotes/guides/security/StandardNames.html#SSLContext
051 */
052 public static String protocol(String scheme) {
053 if( scheme.equals("tls") ) {
054 return "TLS";
055 } else if( scheme.startsWith("tlsv") ) {
056 return "TLSv"+scheme.substring(4);
057 } else if( scheme.equals("ssl") ) {
058 return "SSL";
059 } else if( scheme.startsWith("sslv") ) {
060 return "SSLv"+scheme.substring(4);
061 }
062 return null;
063 }
064
065 enum ClientAuth {
066 WANT, NEED, NONE
067 };
068
069 private ClientAuth clientAuth = ClientAuth.WANT;
070
071 private SSLContext sslContext;
072 private SSLEngine engine;
073
074 private ByteBuffer readBuffer;
075 private boolean readUnderflow;
076
077 private ByteBuffer writeBuffer;
078 private boolean writeFlushing;
079
080 private ByteBuffer readOverflowBuffer;
081 private SSLChannel ssl_channel = new SSLChannel();
082
083 private Executor blockingExecutor;
084
085 public void setSSLContext(SSLContext ctx) {
086 this.sslContext = ctx;
087 }
088
089 /**
090 * Allows subclasses of TcpTransportFactory to create custom instances of
091 * TcpTransport.
092 */
093 public static SslTransport createTransport(URI uri) throws Exception {
094 String protocol = protocol(uri.getScheme());
095 if( protocol !=null ) {
096 SslTransport rc = new SslTransport();
097 rc.setSSLContext(SSLContext.getInstance(protocol));
098 return rc;
099 }
100 return null;
101 }
102
103 public class SSLChannel implements ScatteringByteChannel, GatheringByteChannel {
104
105 public int write(ByteBuffer plain) throws IOException {
106 return secure_write(plain);
107 }
108
109 public int read(ByteBuffer plain) throws IOException {
110 return secure_read(plain);
111 }
112
113 public boolean isOpen() {
114 return getSocketChannel().isOpen();
115 }
116
117 public void close() throws IOException {
118 getSocketChannel().close();
119 }
120
121 public long write(ByteBuffer[] srcs, int offset, int length) throws IOException {
122 if(offset+length > srcs.length || length<0 || offset<0) {
123 throw new IndexOutOfBoundsException();
124 }
125 long rc=0;
126 for (int i = 0; i < length; i++) {
127 ByteBuffer src = srcs[offset+i];
128 if(src.hasRemaining()) {
129 rc += write(src);
130 }
131 if( src.hasRemaining() ) {
132 return rc;
133 }
134 }
135 return rc;
136 }
137
138 public long write(ByteBuffer[] srcs) throws IOException {
139 return write(srcs, 0, srcs.length);
140 }
141
142 public long read(ByteBuffer[] dsts, int offset, int length) throws IOException {
143 if(offset+length > dsts.length || length<0 || offset<0) {
144 throw new IndexOutOfBoundsException();
145 }
146 long rc=0;
147 for (int i = 0; i < length; i++) {
148 ByteBuffer dst = dsts[offset+i];
149 if(dst.hasRemaining()) {
150 rc += read(dst);
151 }
152 if( dst.hasRemaining() ) {
153 return rc;
154 }
155 }
156 return rc;
157 }
158
159 public long read(ByteBuffer[] dsts) throws IOException {
160 return read(dsts, 0, dsts.length);
161 }
162
163 public Socket socket() {
164 SocketChannel c = channel;
165 if( c == null ) {
166 return null;
167 }
168 return c.socket();
169 }
170 }
171
172 public SSLSession getSSLSession() {
173 return engine==null ? null : engine.getSession();
174 }
175
176 public X509Certificate[] getPeerX509Certificates() {
177 if( engine==null ) {
178 return null;
179 }
180 try {
181 ArrayList<X509Certificate> rc = new ArrayList<X509Certificate>();
182 for( Certificate c:engine.getSession().getPeerCertificates() ) {
183 if(c instanceof X509Certificate) {
184 rc.add((X509Certificate) c);
185 }
186 }
187 return rc.toArray(new X509Certificate[rc.size()]);
188 } catch (SSLPeerUnverifiedException e) {
189 return null;
190 }
191 }
192
193 @Override
194 public void connecting(URI remoteLocation, URI localLocation) throws Exception {
195 assert engine == null;
196 engine = sslContext.createSSLEngine();
197 engine.setUseClientMode(true);
198 super.connecting(remoteLocation, localLocation);
199 }
200
201 @Override
202 public void connected(SocketChannel channel) throws Exception {
203 if (engine == null) {
204 engine = sslContext.createSSLEngine();
205 engine.setUseClientMode(false);
206 switch (clientAuth) {
207 case WANT: engine.setWantClientAuth(true); break;
208 case NEED: engine.setNeedClientAuth(true); break;
209 case NONE: engine.setWantClientAuth(false); break;
210 }
211
212 }
213 super.connected(channel);
214 }
215
216 @Override
217 protected void initializeChannel() throws Exception {
218 super.initializeChannel();
219 SSLSession session = engine.getSession();
220 readBuffer = ByteBuffer.allocateDirect(session.getPacketBufferSize());
221 readBuffer.flip();
222 writeBuffer = ByteBuffer.allocateDirect(session.getPacketBufferSize());
223 }
224
225 @Override
226 protected void onConnected() throws IOException {
227 super.onConnected();
228 engine.beginHandshake();
229 handshake();
230 }
231
232 @Override
233 public void flush() {
234 if ( engine.getHandshakeStatus()!=NOT_HANDSHAKING ) {
235 handshake();
236 } else {
237 super.flush();
238 }
239 }
240
241 @Override
242 protected void drainInbound() {
243 if ( engine.getHandshakeStatus()!=NOT_HANDSHAKING ) {
244 handshake();
245 } else {
246 super.drainInbound();
247 }
248 }
249
250 /**
251 * @return true if fully flushed.
252 * @throws IOException
253 */
254 protected boolean transportFlush() throws IOException {
255 while (true) {
256 if(writeFlushing) {
257 int count = super.writeChannel().write(writeBuffer);
258 if( !writeBuffer.hasRemaining() ) {
259 writeBuffer.clear();
260 writeFlushing = false;
261 suspendWrite();
262 return true;
263 } else {
264 return false;
265 }
266 } else {
267 if( writeBuffer.position()!=0 ) {
268 writeBuffer.flip();
269 writeFlushing = true;
270 resumeWrite();
271 } else {
272 return true;
273 }
274 }
275 }
276 }
277
278 private int secure_write(ByteBuffer plain) throws IOException {
279 if( !transportFlush() ) {
280 // can't write anymore until the write_secured_buffer gets fully flushed out..
281 return 0;
282 }
283 int rc = 0;
284 while ( plain.hasRemaining() ^ engine.getHandshakeStatus()==NEED_WRAP ) {
285 SSLEngineResult result = engine.wrap(plain, writeBuffer);
286 assert result.getStatus()!= BUFFER_OVERFLOW;
287 rc += result.bytesConsumed();
288 if( !transportFlush() ) {
289 break;
290 }
291 }
292 if( plain.remaining()==0 && engine.getHandshakeStatus()!=NOT_HANDSHAKING ) {
293 dispatchQueue.execute(new Task() {
294 public void run() {
295 handshake();
296 }
297 });
298 }
299 return rc;
300 }
301
302 private int secure_read(ByteBuffer plain) throws IOException {
303 int rc=0;
304 while ( plain.hasRemaining() ^ engine.getHandshakeStatus() == NEED_UNWRAP ) {
305 if( readOverflowBuffer !=null ) {
306 if( plain.hasRemaining() ) {
307 // lets drain the overflow buffer before trying to suck down anymore
308 // network bytes.
309 int size = Math.min(plain.remaining(), readOverflowBuffer.remaining());
310 plain.put(readOverflowBuffer.array(), readOverflowBuffer.position(), size);
311 readOverflowBuffer.position(readOverflowBuffer.position()+size);
312 if( !readOverflowBuffer.hasRemaining() ) {
313 readOverflowBuffer = null;
314 }
315 rc += size;
316 } else {
317 return rc;
318 }
319 } else if( readUnderflow ) {
320 int count = super.readChannel().read(readBuffer);
321 if( count == -1 ) { // peer closed socket.
322 if (rc==0) {
323 return -1;
324 } else {
325 return rc;
326 }
327 }
328 if( count==0 ) { // no data available right now.
329 return rc;
330 }
331 // read in some more data, perhaps now we can unwrap.
332 readUnderflow = false;
333 readBuffer.flip();
334 } else {
335 SSLEngineResult result = engine.unwrap(readBuffer, plain);
336 rc += result.bytesProduced();
337 if( result.getStatus() == BUFFER_OVERFLOW ) {
338 readOverflowBuffer = ByteBuffer.allocate(engine.getSession().getApplicationBufferSize());
339 result = engine.unwrap(readBuffer, readOverflowBuffer);
340 if( readOverflowBuffer.position()==0 ) {
341 readOverflowBuffer = null;
342 } else {
343 readOverflowBuffer.flip();
344 }
345 }
346 switch( result.getStatus() ) {
347 case CLOSED:
348 if (rc==0) {
349 engine.closeInbound();
350 return -1;
351 } else {
352 return rc;
353 }
354 case OK:
355 if ( engine.getHandshakeStatus()!=NOT_HANDSHAKING ) {
356 dispatchQueue.execute(new Task() {
357 public void run() {
358 handshake();
359 }
360 });
361 }
362 break;
363 case BUFFER_UNDERFLOW:
364 readBuffer.compact();
365 readUnderflow = true;
366 break;
367 case BUFFER_OVERFLOW:
368 throw new AssertionError("Unexpected case.");
369 }
370 }
371 }
372 return rc;
373 }
374
375 public void handshake() {
376 try {
377 if( !transportFlush() ) {
378 return;
379 }
380 switch (engine.getHandshakeStatus()) {
381 case NEED_TASK:
382 final Runnable task = engine.getDelegatedTask();
383 if( task!=null ) {
384 blockingExecutor.execute(new Task() {
385 public void run() {
386 task.run();
387 dispatchQueue.execute(new Task() {
388 public void run() {
389 if (isConnected()) {
390 handshake();
391 }
392 }
393 });
394 }
395 });
396 }
397 break;
398
399 case NEED_WRAP:
400 secure_write(ByteBuffer.allocate(0));
401 break;
402
403 case NEED_UNWRAP:
404 if( secure_read(ByteBuffer.allocate(0)) == -1) {
405 throw new EOFException("Peer disconnected during ssl handshake");
406 }
407 break;
408
409 case FINISHED:
410 case NOT_HANDSHAKING:
411 drainOutboundSource.merge(1);
412 break;
413
414 default:
415 System.err.println("Unexpected ssl engine handshake status: "+ engine.getHandshakeStatus());
416 break;
417 }
418 } catch (IOException e ) {
419 onTransportFailure(e);
420 }
421 }
422
423
424 public ReadableByteChannel readChannel() {
425 return ssl_channel;
426 }
427
428 public WritableByteChannel writeChannel() {
429 return ssl_channel;
430 }
431
432 public Executor getBlockingExecutor() {
433 return blockingExecutor;
434 }
435
436 public void setBlockingExecutor(Executor blockingExecutor) {
437 this.blockingExecutor = blockingExecutor;
438 }
439
440 public String getClientAuth() {
441 return clientAuth.name();
442 }
443
444 public void setClientAuth(String clientAuth) {
445 this.clientAuth = ClientAuth.valueOf(clientAuth.toUpperCase());
446 }
447 }
448
449