package org.xbib.io.sshd.fs;

import org.xbib.io.sshd.client.subsystem.sftp.SftpClient;
import org.xbib.io.sshd.common.subsystem.sftp.SftpConstants;
import org.xbib.io.sshd.common.subsystem.sftp.SftpException;
import org.xbib.io.sshd.common.util.GenericUtils;
import org.xbib.io.sshd.common.util.ValidateUtils;
import org.xbib.io.sshd.common.util.io.IoUtils;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.MappedByteBuffer;
import java.nio.channels.AsynchronousCloseException;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.FileChannel;
import java.nio.channels.FileLock;
import java.nio.channels.OverlappingFileLockException;
import java.nio.channels.ReadableByteChannel;
import java.nio.channels.WritableByteChannel;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.EnumSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;

/**
 * A {@link FileChannel} whose content lives in a remote file accessed through
 * an {@link SftpClient}. The remote file is opened in the constructor and the
 * resulting handle is used for every subsequent operation.
 *
 * <p>The channel keeps its own notion of the &quot;current position&quot;:
 * the relative {@code read}/{@code write} variants start at that position and
 * advance it, whereas the positional variants and the {@code transferXXX}
 * methods leave it untouched. All data transfer methods are serialized on an
 * internal lock object.</p>
 */
public class SftpRemotePathChannel extends FileChannel {
    /** Session property name used to configure the {@link #transferFrom} buffer size. */
    public static final String COPY_BUFSIZE_PROP = "sftp-channel-copy-buf-size";
    /** Buffer size used by {@link #transferFrom} when {@link #COPY_BUFSIZE_PROP} is not set. */
    public static final int DEFAULT_TRANSFER_BUFFER_SIZE = IoUtils.DEFAULT_COPY_SIZE;

    /** Open modes of which at least one must be present for read operations. */
    public static final Set<SftpClient.OpenMode> READ_MODES =
            Collections.unmodifiableSet(EnumSet.of(SftpClient.OpenMode.Read));

    /** Open modes of which at least one must be present for write operations. */
    public static final Set<SftpClient.OpenMode> WRITE_MODES =
            Collections.unmodifiableSet(
                    EnumSet.of(SftpClient.OpenMode.Write, SftpClient.OpenMode.Append, SftpClient.OpenMode.Create, SftpClient.OpenMode.Truncate));

    private final String path;
    private final Collection<SftpClient.OpenMode> modes;
    /** Whether the SFTP client itself is closed together with this channel. */
    private final boolean closeOnExit;
    private final SftpClient sftp;
    private final SftpClient.CloseableHandle handle;
    /** Serializes the data transfer operations of this channel. */
    private final Object lock = new Object();
    /** Position used (and advanced) by the relative read/write variants. */
    private final AtomicLong posTracker = new AtomicLong(0L);
    /** Thread currently inside a transfer operation - interrupted when the channel is closed. */
    private final AtomicReference<Thread> blockingThreadHolder = new AtomicReference<>(null);

    /**
     * Opens the remote file and wraps the resulting handle.
     *
     * @param path        The remote file path - may not be {@code null}/empty
     * @param sftp        The {@link SftpClient} used to access the file
     * @param closeOnExit If {@code true} then the client is closed as well when
     *                    this channel is closed
     * @param modes       The modes used to open the file - they are also consulted
     *                    later on in order to decide whether a read/write is allowed
     * @throws IOException If failed to open the remote file
     */
    public SftpRemotePathChannel(String path, SftpClient sftp, boolean closeOnExit, Collection<SftpClient.OpenMode> modes) throws IOException {
        this.path = ValidateUtils.checkNotNullAndNotEmpty(path, "No remote file path specified");
        this.modes = Objects.requireNonNull(modes, "No channel modes specified");
        this.sftp = Objects.requireNonNull(sftp, "No SFTP client instance");
        this.closeOnExit = closeOnExit;
        this.handle = sftp.open(path, modes);
    }

    /**
     * @return The remote path provided to the constructor
     */
    public String getRemotePath() {
        return path;
    }

    /**
     * Relative read - starts at the tracked position and advances it.
     */
    @Override
    public int read(ByteBuffer dst) throws IOException {
        return (int) doRead(Collections.singletonList(dst), -1);
    }

    /**
     * Positional read - the tracked position is neither used nor modified.
     *
     * @throws IllegalArgumentException If negative position specified
     */
    @Override
    public int read(ByteBuffer dst, long position) throws IOException {
        if (position < 0) {
            throw new IllegalArgumentException("read(" + getRemotePath() + ") illegal position to read from: " + position);
        }
        return (int) doRead(Collections.singletonList(dst), position);
    }

    /**
     * Relative scattering read into {@code dsts[offset] .. dsts[offset + length - 1]}.
     */
    @Override
    public long read(ByteBuffer[] dsts, int offset, int length) throws IOException {
        List<ByteBuffer> buffers = Arrays.asList(dsts).subList(offset, offset + length);
        return doRead(buffers, -1);
    }

    /**
     * Fills the given buffers, in order, with data read from the remote file.
     * Reading goes on until all the buffers are full or the client reports a
     * non-positive read count.
     *
     * @param buffers  The target buffers
     * @param position The file position to start reading from - if negative then
     *                 the tracked position is used, and updated when done (even if
     *                 an exception is thrown midway)
     * @return The total number of bytes read. If nothing was read then -1 if the
     * client signaled end-of-file, zero otherwise
     * @throws IOException If channel not open for reading or failed to read
     */
    protected long doRead(List<ByteBuffer> buffers, long position) throws IOException {
        ensureOpen(READ_MODES);
        synchronized (lock) {
            boolean completed = false;
            boolean eof = false;
            long curPos = (position >= 0L) ? position : posTracker.get();
            try {
                long totalRead = 0;
                beginBlocking();
                loop:
                for (ByteBuffer buffer : buffers) {
                    while (buffer.remaining() > 0) {
                        ByteBuffer wrap = buffer;
                        // no accessible backing array - go through a temporary heap buffer
                        if (!buffer.hasArray()) {
                            wrap = ByteBuffer.allocate(Math.min(IoUtils.DEFAULT_COPY_SIZE, buffer.remaining()));
                        }
                        int read = sftp.read(handle, curPos, wrap.array(), wrap.arrayOffset() + wrap.position(), wrap.remaining());
                        if (read > 0) {
                            if (wrap == buffer) {
                                // data landed directly in the caller's array - just account for it
                                wrap.position(wrap.position() + read);
                            } else {
                                buffer.put(wrap.array(), wrap.arrayOffset(), read);
                            }
                            curPos += read;
                            totalRead += read;
                        } else {
                            eof = read == -1;
                            break loop;
                        }
                    }
                }
                completed = true;
                if (totalRead > 0) {
                    return totalRead;
                }

                if (eof) {
                    return -1;
                } else {
                    return 0;
                }
            } finally {
                if (position < 0L) {
                    posTracker.set(curPos);
                }
                endBlocking(completed);
            }
        }
    }

    /**
     * Relative write - starts at the tracked position and advances it.
     */
    @Override
    public int write(ByteBuffer src) throws IOException {
        return (int) doWrite(Collections.singletonList(src), -1);
    }

    /**
     * Positional write - the tracked position is neither used nor modified.
     *
     * @throws IllegalArgumentException If negative position specified
     */
    @Override
    public int write(ByteBuffer src, long position) throws IOException {
        if (position < 0L) {
            throw new IllegalArgumentException("write(" + getRemotePath() + ") illegal position to write to: " + position);
        }
        return (int) doWrite(Collections.singletonList(src), position);
    }

    /**
     * Relative gathering write from {@code srcs[offset] .. srcs[offset + length - 1]}.
     */
    @Override
    public long write(ByteBuffer[] srcs, int offset, int length) throws IOException {
        List<ByteBuffer> buffers = Arrays.asList(srcs).subList(offset, offset + length);
        return doWrite(buffers, -1);
    }

    /**
     * Writes the remaining content of the given buffers, in order, to the
     * remote file.
     *
     * @param buffers  The source buffers
     * @param position The file position to start writing at - if negative then
     *                 the tracked position is used, and updated when done (even if
     *                 an exception is thrown midway)
     * @return The total number of bytes written
     * @throws IOException If channel not open for writing or failed to write
     */
    protected long doWrite(List<ByteBuffer> buffers, long position) throws IOException {
        ensureOpen(WRITE_MODES);
        synchronized (lock) {
            boolean completed = false;
            long curPos = (position >= 0L) ? position : posTracker.get();
            try {
                long totalWritten = 0L;
                beginBlocking();
                for (ByteBuffer buffer : buffers) {
                    while (buffer.remaining() > 0) {
                        ByteBuffer wrap = buffer;
                        // no accessible backing array - copy a chunk into a temporary heap buffer
                        // (the bulk get also advances the source buffer's position)
                        if (!buffer.hasArray()) {
                            wrap = ByteBuffer.allocate(Math.min(IoUtils.DEFAULT_COPY_SIZE, buffer.remaining()));
                            buffer.get(wrap.array(), wrap.arrayOffset(), wrap.remaining());
                        }
                        int written = wrap.remaining();
                        sftp.write(handle, curPos, wrap.array(), wrap.arrayOffset() + wrap.position(), written);
                        if (wrap == buffer) {
                            wrap.position(wrap.position() + written);
                        }
                        curPos += written;
                        totalWritten += written;
                    }
                }
                completed = true;
                return totalWritten;
            } finally {
                if (position < 0L) {
                    posTracker.set(curPos);
                }
                endBlocking(completed);
            }
        }
    }

    /**
     * @return The tracked position used by the relative read/write variants
     * @throws IOException If channel not open
     */
    @Override
    public long position() throws IOException {
        ensureOpen(Collections.emptySet());
        return posTracker.get();
    }

    /**
     * Replaces the tracked position used by the relative read/write variants.
     *
     * @throws IllegalArgumentException If negative position specified
     * @throws IOException              If channel not open
     */
    @Override
    public FileChannel position(long newPosition) throws IOException {
        if (newPosition < 0L) {
            throw new IllegalArgumentException("position(" + getRemotePath() + ") illegal file channel position: " + newPosition);
        }

        ensureOpen(Collections.emptySet());
        posTracker.set(newPosition);
        return this;
    }

    /**
     * @return The size reported by the client's {@code stat} call on the open handle
     * @throws IOException If channel not open or failed to query the attributes
     */
    @Override
    public long size() throws IOException {
        ensureOpen(Collections.emptySet());
        return sftp.stat(handle).getSize();
    }

    /**
     * Requests that the remote file size attribute be set to the given value.
     * The tracked position is not modified by this call.
     *
     * @throws IOException If channel not open or failed to update the attributes
     */
    @Override
    public FileChannel truncate(long size) throws IOException {
        ensureOpen(Collections.emptySet());
        sftp.setStat(handle, new SftpClient.Attributes().size(size));
        return this;
    }

    /**
     * Only verifies that the channel is open - no request is sent to the server.
     *
     * @throws IOException If channel not open
     */
    @Override
    public void force(boolean metaData) throws IOException {
        ensureOpen(Collections.emptySet());
    }

    /**
     * Copies up to {@code count} bytes, starting at the given remote file
     * position, into the target channel. The tracked position of this channel
     * is not modified.
     *
     * @param position The remote file position to start reading from
     * @param count    The maximum number of bytes to transfer
     * @param target   The channel to write the data to
     * @return The number of bytes actually transferred - may be less than
     * {@code count} (including zero) if the client reports a non-positive read
     * count (e.g., end-of-file) before that many bytes were available
     * @throws IllegalArgumentException If negative position or count
     * @throws IOException              If channel not open for reading or failed
     *                                  to read/write the data
     */
    @Override
    public long transferTo(long position, long count, WritableByteChannel target) throws IOException {
        if ((position < 0) || (count < 0)) {
            throw new IllegalArgumentException("transferTo(" + getRemotePath() + ") illegal position (" + position + ") or count (" + count + ")");
        }
        ensureOpen(READ_MODES);
        synchronized (lock) {
            boolean completed = false;
            long curPos = position;
            try {
                beginBlocking();

                int bufSize = (int) Math.min(count, Short.MAX_VALUE + 1);
                byte[] buffer = new byte[bufSize];
                long totalRead = 0L;
                while (totalRead < count) {
                    // never request more than what is still missing so that "count" is not exceeded
                    int toRead = (int) Math.min(buffer.length, count - totalRead);
                    int read = sftp.read(handle, curPos, buffer, 0, toRead);
                    if (read <= 0) {
                        // end-of-file (or nothing available) - stop instead of re-trying forever
                        break;
                    }

                    // expose only the bytes just read - the rest of the array holds stale data
                    ByteBuffer wrap = ByteBuffer.wrap(buffer, 0, read);
                    while (wrap.remaining() > 0) {
                        target.write(wrap);
                    }
                    curPos += read;
                    totalRead += read;
                }
                completed = true;
                return totalRead;
            } finally {
                endBlocking(completed);
            }
        }
    }

    /**
     * Copies up to {@code count} bytes from the source channel into the remote
     * file, starting at the given remote file position. The tracked position of
     * this channel is not modified. The copy buffer size is taken from the
     * {@link #COPY_BUFSIZE_PROP} session property, defaulting to
     * {@link #DEFAULT_TRANSFER_BUFFER_SIZE}.
     *
     * @param src      The channel to read the data from
     * @param position The remote file position to start writing at
     * @param count    The maximum number of bytes to transfer
     * @return The number of bytes actually transferred - may be less than
     * {@code count} if the source reports a non-positive read count
     * @throws IllegalArgumentException If negative position or count
     * @throws IOException              If channel not open for writing or failed
     *                                  to read/write the data
     */
    @Override
    public long transferFrom(ReadableByteChannel src, long position, long count) throws IOException {
        if ((position < 0) || (count < 0)) {
            throw new IllegalArgumentException("transferFrom(" + getRemotePath() + ") illegal position (" + position + ") or count (" + count + ")");
        }
        ensureOpen(WRITE_MODES);

        int copySize = sftp.getClientSession().getIntProperty(COPY_BUFSIZE_PROP, DEFAULT_TRANSFER_BUFFER_SIZE);
        boolean completed = false;
        long curPos = position;    // already validated as non-negative
        long totalRead = 0L;
        byte[] buffer = new byte[(int) Math.min(copySize, count)];

        synchronized (lock) {
            try {
                beginBlocking();

                while (totalRead < count) {
                    // limit the view so that no more than the missing byte count is read from the source
                    ByteBuffer wrap = ByteBuffer.wrap(buffer, 0, (int) Math.min(buffer.length, count - totalRead));
                    int read = src.read(wrap);
                    if (read > 0) {
                        sftp.write(handle, curPos, buffer, 0, read);
                        curPos += read;
                        totalRead += read;
                    } else {
                        break;
                    }
                }
                completed = true;
                return totalRead;
            } finally {
                endBlocking(completed);
            }
        }
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException Always
     */
    @Override
    public MappedByteBuffer map(MapMode mode, long position, long size) throws IOException {
        throw new UnsupportedOperationException("map(" + getRemotePath() + ")[" + mode + "," + position + "," + size + "] N/A");
    }

    /**
     * Delegates to {@link #tryLock(long, long, boolean)}.
     */
    @Override
    public FileLock lock(long position, long size, boolean shared) throws IOException {
        return tryLock(position, size, shared);
    }

    /**
     * Issues a lock request for the specified region via the client. Note that
     * the {@code shared} flag is only recorded on the returned {@link FileLock}
     * instance - the request itself is always sent with a zero mask.
     *
     * @param position The region start
     * @param size     The region size
     * @param shared   The value reported by {@link FileLock#isShared()}
     * @return A {@link FileLock} that issues the matching unlock request the
     * first time it is released
     * @throws OverlappingFileLockException If the server reported a
     *                                      {@code SSH_FX_LOCK_CONFLICT} status
     * @throws IOException                  If channel not open or the request
     *                                      failed for any other reason
     */
    @Override
    public FileLock tryLock(final long position, final long size, boolean shared) throws IOException {
        ensureOpen(Collections.emptySet());

        try {
            sftp.lock(handle, position, size, 0);
        } catch (SftpException e) {
            if (e.getStatus() == SftpConstants.SSH_FX_LOCK_CONFLICT) {
                throw new OverlappingFileLockException();
            }
            throw e;
        }

        return new FileLock(this, position, size, shared) {
            // makes sure the unlock request is sent at most once
            private final AtomicBoolean valid = new AtomicBoolean(true);

            @Override
            public boolean isValid() {
                return acquiredBy().isOpen() && valid.get();
            }

            @SuppressWarnings("synthetic-access")
            @Override
            public void release() throws IOException {
                if (valid.compareAndSet(true, false)) {
                    sftp.unlock(handle, position, size);
                }
            }
        };
    }

    /**
     * Interrupts the thread (if any) currently registered as being inside a
     * transfer operation, closes the remote file handle and - if so configured
     * at construction time - the SFTP client as well.
     */
    @Override
    protected void implCloseChannel() throws IOException {
        try {
            final Thread thread = blockingThreadHolder.get();
            if (thread != null) {
                thread.interrupt();
            }
        } finally {
            try {
                handle.close();
            } finally {
                if (closeOnExit) {
                    sftp.close();
                }
            }
        }
    }

    /**
     * Marks the start of a transfer operation and registers the calling thread
     * so that {@link #implCloseChannel()} can interrupt it.
     */
    private void beginBlocking() {
        begin();
        blockingThreadHolder.set(Thread.currentThread());
    }

    /**
     * Marks the end of a transfer operation and un-registers the calling thread.
     *
     * @param completed Whether the operation ran to completion
     * @throws AsynchronousCloseException As thrown by {@code end(boolean)}
     */
    private void endBlocking(boolean completed) throws AsynchronousCloseException {
        blockingThreadHolder.set(null);
        end(completed);
    }

    /**
     * Checks that the channel is open and that its current mode contains
     * at least one of the required ones
     *
     * @param reqModes The required modes - ignored if {@code null}/empty
     * @throws IOException If channel not open or the required modes are not
     *                     satisfied
     */
    private void ensureOpen(Collection<SftpClient.OpenMode> reqModes) throws IOException {
        if (!isOpen()) {
            throw new ClosedChannelException();
        }

        if (GenericUtils.size(reqModes) > 0) {
            for (SftpClient.OpenMode m : reqModes) {
                if (this.modes.contains(m)) {
                    return;
                }
            }

            throw new IOException("ensureOpen(" + getRemotePath() + ") current channel modes (" + this.modes + ") do not contain any of the required: " + reqModes);
        }
    }

    /**
     * @return The remote path
     */
    @Override
    public String toString() {
        return getRemotePath();
    }
}
