001/**
002 * Copyright (C) 2012 FuseSource, Inc.
003 * http://fusesource.com
004 *
005 * Licensed under the Apache License, Version 2.0 (the "License");
006 * you may not use this file except in compliance with the License.
007 * You may obtain a copy of the License at
008 *
009 *    http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.fusesource.hawtdispatch.transport;
019
020import org.fusesource.hawtdispatch.*;
021
022import java.io.IOException;
023import java.net.*;
024import java.nio.ByteBuffer;
025import java.nio.channels.ReadableByteChannel;
026import java.nio.channels.SelectionKey;
027import java.nio.channels.SocketChannel;
028import java.nio.channels.WritableByteChannel;
029import java.util.LinkedList;
030import java.util.concurrent.Executor;
031import java.util.concurrent.TimeUnit;
032
033/**
034 * An implementation of the {@link org.fusesource.hawtdispatch.transport.Transport} interface using raw tcp/ip
035 *
036 * @author <a href="http://hiramchirino.com">Hiram Chirino</a>
037 */
038public class TcpTransport extends ServiceBase implements Transport {
039
040    static InetAddress localhost;
041    synchronized static public InetAddress getLocalHost() throws UnknownHostException {
042        // cache it...
043        if( localhost==null ) {
044            // this can be slow on some systems and we use repeatedly.
045            localhost = InetAddress.getLocalHost();
046        }
047        return localhost;
048    }
049
050    abstract static class SocketState {
051        void onStop(Task onCompleted) {
052        }
053        void onCanceled() {
054        }
055        boolean is(Class<? extends SocketState> clazz) {
056            return getClass()==clazz;
057        }
058    }
059
060    static class DISCONNECTED extends SocketState{}
061
062    class CONNECTING extends SocketState{
063        void onStop(Task onCompleted) {
064            trace("CONNECTING.onStop");
065            CANCELING state = new CANCELING();
066            socketState = state;
067            state.onStop(onCompleted);
068        }
069        void onCanceled() {
070            trace("CONNECTING.onCanceled");
071            CANCELING state = new CANCELING();
072            socketState = state;
073            state.onCanceled();
074        }
075    }
076
077    class CONNECTED extends SocketState {
078
079        public CONNECTED() {
080            localAddress = channel.socket().getLocalSocketAddress();
081            remoteAddress = channel.socket().getRemoteSocketAddress();
082        }
083
084        void onStop(Task onCompleted) {
085            trace("CONNECTED.onStop");
086            CANCELING state = new CANCELING();
087            socketState = state;
088            state.add(createDisconnectTask());
089            state.onStop(onCompleted);
090        }
091        void onCanceled() {
092            trace("CONNECTED.onCanceled");
093            CANCELING state = new CANCELING();
094            socketState = state;
095            state.add(createDisconnectTask());
096            state.onCanceled();
097        }
098        Task createDisconnectTask() {
099            return new Task(){
100                public void run() {
101                    listener.onTransportDisconnected();
102                }
103            };
104        }
105    }
106
107    class CANCELING extends SocketState {
108        private LinkedList<Task> runnables =  new LinkedList<Task>();
109        private int remaining;
110        private boolean dispose;
111
112        public CANCELING() {
113            if( readSource!=null ) {
114                remaining++;
115                readSource.cancel();
116            }
117            if( writeSource!=null ) {
118                remaining++;
119                writeSource.cancel();
120            }
121        }
122        void onStop(Task onCompleted) {
123            trace("CANCELING.onCompleted");
124            add(onCompleted);
125            dispose = true;
126        }
127        void add(Task onCompleted) {
128            if( onCompleted!=null ) {
129                runnables.add(onCompleted);
130            }
131        }
132        void onCanceled() {
133            trace("CANCELING.onCanceled");
134            remaining--;
135            if( remaining!=0 ) {
136                return;
137            }
138            try {
139                if( closeOnCancel ) {
140                    channel.close();
141                }
142            } catch (IOException ignore) {
143            }
144            socketState = new CANCELED(dispose);
145            for (Task runnable : runnables) {
146                runnable.run();
147            }
148            if (dispose) {
149                dispose();
150            }
151        }
152    }
153
154    class CANCELED extends SocketState {
155        private boolean disposed;
156
157        public CANCELED(boolean disposed) {
158            this.disposed=disposed;
159        }
160
161        void onStop(Task onCompleted) {
162            trace("CANCELED.onStop");
163            if( !disposed ) {
164                disposed = true;
165                dispose();
166            }
167            onCompleted.run();
168        }
169    }
170
171    protected URI remoteLocation;
172    protected URI localLocation;
173    protected TransportListener listener;
174    protected ProtocolCodec codec;
175
176    protected SocketChannel channel;
177
178    protected SocketState socketState = new DISCONNECTED();
179
180    protected DispatchQueue dispatchQueue;
181    private DispatchSource readSource;
182    private DispatchSource writeSource;
183    protected CustomDispatchSource<Integer, Integer> drainOutboundSource;
184    protected CustomDispatchSource<Integer, Integer> yieldSource;
185
186    protected boolean useLocalHost = true;
187
188    int maxReadRate;
189    int maxWriteRate;
190    int receiveBufferSize = 1024*64;
191    int sendBufferSize = 1024*64;
192    boolean closeOnCancel = true;
193
194    boolean keepAlive = true;
195
196    public static final int IPTOS_LOWCOST = 0x02;
197    public static final int IPTOS_RELIABILITY = 0x04;
198    public static final int IPTOS_THROUGHPUT = 0x08;
199    public static final int IPTOS_LOWDELAY = 0x10;
200
201    int trafficClass = IPTOS_THROUGHPUT;
202
203    protected RateLimitingChannel rateLimitingChannel;
204    SocketAddress localAddress;
205    SocketAddress remoteAddress;
206    protected Executor blockingExecutor;
207
208    class RateLimitingChannel implements ReadableByteChannel, WritableByteChannel {
209
210        int read_allowance = maxReadRate;
211        boolean read_suspended = false;
212        int read_resume_counter = 0;
213        int write_allowance = maxWriteRate;
214        boolean write_suspended = false;
215
216        public void resetAllowance() {
217            if( read_allowance != maxReadRate || write_allowance != maxWriteRate) {
218                read_allowance = maxReadRate;
219                write_allowance = maxWriteRate;
220                if( write_suspended ) {
221                    write_suspended = false;
222                    resumeWrite();
223                }
224                if( read_suspended ) {
225                    read_suspended = false;
226                    resumeRead();
227                    for( int i=0; i < read_resume_counter ; i++ ) {
228                        resumeRead();
229                    }
230                }
231            }
232        }
233
234        public int read(ByteBuffer dst) throws IOException {
235            if( maxReadRate ==0 ) {
236                return channel.read(dst);
237            } else {
238                int remaining = dst.remaining();
239                if( read_allowance ==0 || remaining ==0 ) {
240                    return 0;
241                }
242
243                int reduction = 0;
244                if( remaining > read_allowance) {
245                    reduction = remaining - read_allowance;
246                    dst.limit(dst.limit() - reduction);
247                }
248                int rc=0;
249                try {
250                    rc = channel.read(dst);
251                    read_allowance -= rc;
252                } finally {
253                    if( reduction!=0 ) {
254                        if( dst.remaining() == 0 ) {
255                            // we need to suspend the read now until we get
256                            // a new allowance..
257                            readSource.suspend();
258                            read_suspended = true;
259                        }
260                        dst.limit(dst.limit() + reduction);
261                    }
262                }
263                return rc;
264            }
265        }
266
267        public int write(ByteBuffer src) throws IOException {
268            if( maxWriteRate ==0 ) {
269                return channel.write(src);
270            } else {
271                int remaining = src.remaining();
272                if( write_allowance ==0 || remaining ==0 ) {
273                    return 0;
274                }
275
276                int reduction = 0;
277                if( remaining > write_allowance) {
278                    reduction = remaining - write_allowance;
279                    src.limit(src.limit() - reduction);
280                }
281                int rc = 0;
282                try {
283                    rc = channel.write(src);
284                    write_allowance -= rc;
285                } finally {
286                    if( reduction!=0 ) {
287                        if( src.remaining() == 0 ) {
288                            // we need to suspend the read now until we get
289                            // a new allowance..
290                            write_suspended = true;
291                            suspendWrite();
292                        }
293                        src.limit(src.limit() + reduction);
294                    }
295                }
296                return rc;
297            }
298        }
299
300        public boolean isOpen() {
301            return channel.isOpen();
302        }
303
304        public void close() throws IOException {
305            channel.close();
306        }
307
308        public void resumeRead() {
309            if( read_suspended ) {
310                read_resume_counter += 1;
311            } else {
312                _resumeRead();
313            }
314        }
315
316    }
317
318    private final Task CANCEL_HANDLER = new Task() {
319        public void run() {
320            socketState.onCanceled();
321        }
322    };
323
324    static final class OneWay {
325        final Object command;
326        final Retained retained;
327
328        public OneWay(Object command, Retained retained) {
329            this.command = command;
330            this.retained = retained;
331        }
332    }
333
334    public void connected(SocketChannel channel) throws IOException, Exception {
335        this.channel = channel;
336        initializeChannel();
337        this.socketState = new CONNECTED();
338    }
339
340    protected void initializeChannel() throws Exception {
341        this.channel.configureBlocking(false);
342        Socket socket = channel.socket();
343        try {
344            socket.setReuseAddress(true);
345        } catch (SocketException e) {
346        }
347        try {
348            socket.setSoLinger(true, 0);
349        } catch (SocketException e) {
350        }
351        try {
352            socket.setTrafficClass(trafficClass);
353        } catch (SocketException e) {
354        }
355        try {
356            socket.setKeepAlive(keepAlive);
357        } catch (SocketException e) {
358        }
359        try {
360            socket.setTcpNoDelay(true);
361        } catch (SocketException e) {
362        }
363        try {
364            socket.setReceiveBufferSize(receiveBufferSize);
365        } catch (SocketException e) {
366        }
367        try {
368            socket.setSendBufferSize(sendBufferSize);
369        } catch (SocketException e) {
370        }
371
372        if( channel!=null && codec!=null ) {
373            initializeCodec();
374        }
375    }
376
377    protected void initializeCodec() throws Exception {
378        codec.setTransport(this);
379    }
380
381    public void connecting(final URI remoteLocation, final URI localLocation) throws Exception {
382        this.channel = SocketChannel.open();
383        initializeChannel();
384        this.remoteLocation = remoteLocation;
385        this.localLocation = localLocation;
386        socketState = new CONNECTING();
387    }
388
389
390    public DispatchQueue getDispatchQueue() {
391        return dispatchQueue;
392    }
393
394    public void setDispatchQueue(DispatchQueue queue) {
395        this.dispatchQueue = queue;
396        if(readSource!=null) readSource.setTargetQueue(queue);
397        if(writeSource!=null) writeSource.setTargetQueue(queue);
398        if(drainOutboundSource!=null) drainOutboundSource.setTargetQueue(queue);
399        if(yieldSource!=null) yieldSource.setTargetQueue(queue);
400    }
401
402    public void _start(Task onCompleted) {
403        try {
404            if (socketState.is(CONNECTING.class)) {
405
406                // Resolving host names might block.. so do it on the blocking executor.
407                this.blockingExecutor.execute(new Runnable() {
408                    public void run() {
409                        try {
410
411                            final InetSocketAddress localAddress = (localLocation != null) ?
412                                    new InetSocketAddress(InetAddress.getByName(localLocation.getHost()), localLocation.getPort())
413                                    : null;
414
415                            String host = resolveHostName(remoteLocation.getHost());
416                            final InetSocketAddress remoteAddress = new InetSocketAddress(host, remoteLocation.getPort());
417
418                            // Done resolving.. switch back to the dispatch queue.
419                            dispatchQueue.execute(new Task() {
420                                @Override
421                                public void run() {
422                                    // No need to complete if we have been canceled.
423                                    if( ! socketState.is(CONNECTING.class) ) {
424                                        return;
425                                    }
426                                    try {
427
428                                        if (localAddress != null) {
429                                            channel.socket().bind(localAddress);
430                                        }
431                                        trace("connecting...");
432                                        channel.connect(remoteAddress);
433
434                                        // this allows the connect to complete..
435                                        readSource = Dispatch.createSource(channel, SelectionKey.OP_CONNECT, dispatchQueue);
436                                        readSource.setEventHandler(new Task() {
437                                            public void run() {
438                                                if (getServiceState() != STARTED) {
439                                                    return;
440                                                }
441                                                try {
442                                                    trace("connected.");
443                                                    channel.finishConnect();
444                                                    readSource.setCancelHandler(null);
445                                                    readSource.cancel();
446                                                    readSource = null;
447                                                    socketState = new CONNECTED();
448                                                    onConnected();
449                                                } catch (IOException e) {
450                                                    onTransportFailure(e);
451                                                }
452                                            }
453                                        });
454                                        readSource.setCancelHandler(CANCEL_HANDLER);
455                                        readSource.resume();
456
457                                    } catch (IOException e) {
458                                        try {
459                                            channel.close();
460                                        } catch (IOException ignore) {
461                                        }
462                                        socketState = new CANCELED(true);
463                                        listener.onTransportFailure(e);
464                                    }
465                                }
466                            });
467
468                        } catch (final IOException e) {
469                            // we're in blockingExecutor thread context here
470                            dispatchQueue.execute(new Task() {
471                                public void run() {
472                                    try {
473                                        channel.close();
474                                    } catch (IOException ignore) {
475                                    }
476                                    socketState = new CANCELED(true);
477                                    listener.onTransportFailure(e);
478                                }
479                            });
480                        }
481                    }
482                });
483            } else if (socketState.is(CONNECTED.class)) {
484                dispatchQueue.execute(new Task() {
485                    public void run() {
486                        try {
487                            trace("was connected.");
488                            onConnected();
489                        } catch (IOException e) {
490                            onTransportFailure(e);
491                        }
492                    }
493                });
494            } else {
495                System.err.println("cannot be started.  socket state is: " + socketState);
496            }
497        } finally {
498            if (onCompleted != null) {
499                onCompleted.run();
500            }
501        }
502    }
503
504    public void _stop(final Task onCompleted) {
505        trace("stopping.. at state: "+socketState);
506        socketState.onStop(onCompleted);
507    }
508
509    protected String resolveHostName(String host) throws UnknownHostException {
510        if (isUseLocalHost()) {
511            String localName = getLocalHost().getHostName();
512            if (localName != null && localName.equals(host)) {
513                return "localhost";
514            }
515        }
516        return host;
517    }
518
519    protected void onConnected() throws IOException {
520        yieldSource = Dispatch.createSource(EventAggregators.INTEGER_ADD, dispatchQueue);
521        yieldSource.setEventHandler(new Task() {
522            public void run() {
523                drainInbound();
524            }
525        });
526        yieldSource.resume();
527        drainOutboundSource = Dispatch.createSource(EventAggregators.INTEGER_ADD, dispatchQueue);
528        drainOutboundSource.setEventHandler(new Task() {
529            public void run() {
530                flush();
531            }
532        });
533        drainOutboundSource.resume();
534
535        readSource = Dispatch.createSource(channel, SelectionKey.OP_READ, dispatchQueue);
536        writeSource = Dispatch.createSource(channel, SelectionKey.OP_WRITE, dispatchQueue);
537
538        readSource.setCancelHandler(CANCEL_HANDLER);
539        writeSource.setCancelHandler(CANCEL_HANDLER);
540
541        readSource.setEventHandler(new Task() {
542            public void run() {
543                drainInbound();
544            }
545        });
546        writeSource.setEventHandler(new Task() {
547            public void run() {
548                flush();
549            }
550        });
551
552        if( maxReadRate !=0 || maxWriteRate !=0 ) {
553            rateLimitingChannel = new RateLimitingChannel();
554            schedualRateAllowanceReset();
555        }
556        listener.onTransportConnected();
557    }
558
559    private void schedualRateAllowanceReset() {
560        dispatchQueue.executeAfter(1, TimeUnit.SECONDS, new Task(){
561            public void run() {
562                if( !socketState.is(CONNECTED.class) ) {
563                    return;
564                }
565                rateLimitingChannel.resetAllowance();
566                schedualRateAllowanceReset();
567            }
568        });
569    }
570
571    private void dispose() {
572        if( readSource!=null ) {
573            readSource.cancel();
574            readSource=null;
575        }
576
577        if( writeSource!=null ) {
578            writeSource.cancel();
579            writeSource=null;
580        }
581    }
582
583    public void onTransportFailure(IOException error) {
584        listener.onTransportFailure(error);
585        socketState.onCanceled();
586    }
587
588
589    public boolean full() {
590        return codec==null ||
591               codec.full() ||
592               !socketState.is(CONNECTED.class) ||
593               getServiceState() != STARTED;
594    }
595
596    boolean rejectingOffers;
597
598    public boolean offer(Object command) {
599        dispatchQueue.assertExecuting();
600        if( full() ) {
601            return false;
602        }
603        try {
604            ProtocolCodec.BufferState rc = codec.write(command);
605            rejectingOffers = codec.full();
606            switch (rc ) {
607                case FULL:
608                    return false;
609                default:
610                    drainOutboundSource.merge(1);
611            }
612        } catch (IOException e) {
613            onTransportFailure(e);
614        }
615        return true;
616    }
617
618    boolean writeResumedForCodecFlush = false;
619
620    /**
621     *
622     */
623    public void flush() {
624        dispatchQueue.assertExecuting();
625        if (getServiceState() != STARTED || !socketState.is(CONNECTED.class)) {
626            return;
627        }
628        try {
629            if( codec.flush() == ProtocolCodec.BufferState.EMPTY && transportFlush() ) {
630                if( writeResumedForCodecFlush) {
631                    writeResumedForCodecFlush = false;
632                    suspendWrite();
633                }
634                rejectingOffers = false;
635                listener.onRefill();
636
637            } else {
638                if(!writeResumedForCodecFlush) {
639                    writeResumedForCodecFlush = true;
640                    resumeWrite();
641                }
642            }
643        } catch (IOException e) {
644            onTransportFailure(e);
645        }
646    }
647
648    protected boolean transportFlush() throws IOException {
649        return true;
650    }
651
652    public void drainInbound() {
653        if (!getServiceState().isStarted() || readSource.isSuspended()) {
654            return;
655        }
656        try {
657            long initial = codec.getReadCounter();
658            // Only process upto 2 x the read buffer worth of data at a time so we can give
659            // other connections a chance to process their requests.
660            while( codec.getReadCounter()-initial < codec.getReadBufferSize()<<2 ) {
661                Object command = codec.read();
662                if ( command!=null ) {
663                    try {
664                        listener.onTransportCommand(command);
665                    } catch (Throwable e) {
666                        e.printStackTrace();
667                        onTransportFailure(new IOException("Transport listener failure."));
668                    }
669
670                    // the transport may be suspended after processing a command.
671                    if (getServiceState() == STOPPED || readSource.isSuspended()) {
672                        return;
673                    }
674                } else {
675                    return;
676                }
677            }
678            yieldSource.merge(1);
679        } catch (IOException e) {
680            onTransportFailure(e);
681        }
682    }
683
684    public SocketAddress getLocalAddress() {
685        return localAddress;
686    }
687
688    public SocketAddress getRemoteAddress() {
689        return remoteAddress;
690    }
691
692    private boolean assertConnected() {
693        try {
694            if ( !isConnected() ) {
695                throw new IOException("Not connected.");
696            }
697            return true;
698        } catch (IOException e) {
699            onTransportFailure(e);
700        }
701        return false;
702    }
703
704    public void suspendRead() {
705        if( isConnected() && readSource!=null ) {
706            readSource.suspend();
707        }
708    }
709
710
711    public void resumeRead() {
712        if( isConnected() && readSource!=null ) {
713            if( rateLimitingChannel!=null ) {
714                rateLimitingChannel.resumeRead();
715            } else {
716                _resumeRead();
717            }
718        }
719    }
720
721    private void _resumeRead() {
722        readSource.resume();
723        dispatchQueue.execute(new Task(){
724            public void run() {
725                drainInbound();
726            }
727        });
728    }
729
730    protected void suspendWrite() {
731        if( isConnected() && writeSource!=null ) {
732            writeSource.suspend();
733        }
734    }
735
736    protected void resumeWrite() {
737        if( isConnected() && writeSource!=null ) {
738            writeSource.resume();
739        }
740    }
741
742    public TransportListener getTransportListener() {
743        return listener;
744    }
745
746    public void setTransportListener(TransportListener transportListener) {
747        this.listener = transportListener;
748    }
749
750    public ProtocolCodec getProtocolCodec() {
751        return codec;
752    }
753
754    public void setProtocolCodec(ProtocolCodec protocolCodec) throws Exception {
755        this.codec = protocolCodec;
756        if( channel!=null && codec!=null ) {
757            initializeCodec();
758        }
759    }
760
761    public boolean isConnected() {
762        return socketState.is(CONNECTED.class);
763    }
764
765    public boolean isClosed() {
766        return getServiceState() == STOPPED;
767    }
768
769    public boolean isUseLocalHost() {
770        return useLocalHost;
771    }
772
773    /**
774     * Sets whether 'localhost' or the actual local host name should be used to
775     * make local connections. On some operating systems such as Macs its not
776     * possible to connect as the local host name so localhost is better.
777     */
778    public void setUseLocalHost(boolean useLocalHost) {
779        this.useLocalHost = useLocalHost;
780    }
781
782    private void trace(String message) {
783        // TODO:
784    }
785
786    public SocketChannel getSocketChannel() {
787        return channel;
788    }
789
790    public ReadableByteChannel getReadChannel() {
791        if(rateLimitingChannel!=null) {
792            return rateLimitingChannel;
793        } else {
794            return channel;
795        }
796    }
797
798    public WritableByteChannel getWriteChannel() {
799        if(rateLimitingChannel!=null) {
800            return rateLimitingChannel;
801        } else {
802            return channel;
803        }
804    }
805
806    public int getMaxReadRate() {
807        return maxReadRate;
808    }
809
810    public void setMaxReadRate(int maxReadRate) {
811        this.maxReadRate = maxReadRate;
812    }
813
814    public int getMaxWriteRate() {
815        return maxWriteRate;
816    }
817
818    public void setMaxWriteRate(int maxWriteRate) {
819        this.maxWriteRate = maxWriteRate;
820    }
821
822    public int getTrafficClass() {
823        return trafficClass;
824    }
825
826    public void setTrafficClass(int trafficClass) {
827        this.trafficClass = trafficClass;
828    }
829
830    public int getReceiveBufferSize() {
831        return receiveBufferSize;
832    }
833
834    public void setReceiveBufferSize(int receiveBufferSize) {
835        this.receiveBufferSize = receiveBufferSize;
836        if( channel!=null ) {
837            try {
838                channel.socket().setReceiveBufferSize(receiveBufferSize);
839            } catch (SocketException ignore) {
840            }
841        }
842    }
843
844    public int getSendBufferSize() {
845        return sendBufferSize;
846    }
847
848    public void setSendBufferSize(int sendBufferSize) {
849        this.sendBufferSize = sendBufferSize;
850        if( channel!=null ) {
851            try {
852                channel.socket().setReceiveBufferSize(sendBufferSize);
853            } catch (SocketException ignore) {
854            }
855        }
856    }
857
858    public boolean isKeepAlive() {
859        return keepAlive;
860    }
861
862    public void setKeepAlive(boolean keepAlive) {
863        this.keepAlive = keepAlive;
864    }
865
866    public Executor getBlockingExecutor() {
867        return blockingExecutor;
868    }
869
870    public void setBlockingExecutor(Executor blockingExecutor) {
871        this.blockingExecutor = blockingExecutor;
872    }
873
874    public boolean isCloseOnCancel() {
875        return closeOnCancel;
876    }
877
878    public void setCloseOnCancel(boolean closeOnCancel) {
879        this.closeOnCancel = closeOnCancel;
880    }
881}