package org.xbib.helianthus.client;

import static java.util.Objects.requireNonNull;

import org.xbib.helianthus.common.SessionProtocol;
import org.xbib.helianthus.common.util.LruMap;

import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.text.MessageFormat;
import java.util.EnumSet;
import java.util.Map;
import java.util.Map.Entry;
import java.util.concurrent.locks.StampedLock;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * Keeps the recent {@link SessionProtocol} negotiation failures. It is a LRU cache which keeps at most
 * 64k 'host name + port' pairs.
 *
 * <p>All public methods are safe to call from multiple threads.
 */
public final class SessionProtocolNegotiationCache {

    private static final Logger logger = Logger.getLogger(SessionProtocolNegotiationCache.class.getName());

    /**
     * Guards {@link #cache}. Only the exclusive (write) mode is used: {@code LruMap} is assumed to be an
     * access-ordered map (as an LRU map must be), in which case even {@code get()} reorders entries and
     * is therefore a structural modification that must not run concurrently with other accesses.
     */
    private static final StampedLock lock = new StampedLock();
    private static final Map<String, CacheEntry> cache = new LruMap<String, CacheEntry>(65536) {
        private static final long serialVersionUID = -2506868886873712772L;

        @Override
        protected boolean removeEldestEntry(Entry<String, CacheEntry> eldest) {
            final boolean remove = super.removeEldestEntry(eldest);
            // Guarded so the message is only built when FINE logging is actually enabled.
            if (remove && logger.isLoggable(Level.FINE)) {
                logger.fine(MessageFormat.format("[evicted] {0} does not support {1}",
                        eldest.getKey(), eldest.getValue()));
            }
            return remove;
        }
    };

    private SessionProtocolNegotiationCache() {
    }

    /**
     * Returns {@code true} if the specified {@code remoteAddress} is known to have no support for
     * the specified {@link SessionProtocol}.
     *
     * @param remoteAddress the remote address; must be an {@link InetSocketAddress}
     * @param protocol the protocol to check; must not be {@code null}
     * @return {@code true} if a negotiation failure for {@code protocol} has been recorded for the address
     * @throws NullPointerException if either argument is {@code null}
     * @throws IllegalArgumentException if {@code remoteAddress} is not an {@link InetSocketAddress}
     */
    public static boolean isUnsupported(SocketAddress remoteAddress, SessionProtocol protocol) {
        final String key = key(remoteAddress);
        requireNonNull(protocol, "protocol");
        final CacheEntry e;
        // Exclusive lock even for a lookup: get() on an access-ordered map mutates its internal order.
        final long stamp = lock.writeLock();
        try {
            e = cache.get(key);
        } finally {
            lock.unlockWrite(stamp);
        }
        // No entry means no failure was recorded, so we can't tell that it's unsupported.
        return e != null && e.isUnsupported(protocol);
    }

    /**
     * Updates the cache with the information that the specified {@code remoteAddress} does not support
     * the specified {@link SessionProtocol}.
     *
     * @param remoteAddress the remote address; must be an {@link InetSocketAddress}
     * @param protocol the unsupported protocol; must not be {@code null}
     * @throws NullPointerException if either argument is {@code null}
     * @throws IllegalArgumentException if {@code remoteAddress} is not an {@link InetSocketAddress}
     */
    public static void setUnsupported(SocketAddress remoteAddress, SessionProtocol protocol) {
        final String key = key(remoteAddress);
        // Validate before touching the cache so a null protocol cannot leave an empty entry behind.
        requireNonNull(protocol, "protocol");
        final CacheEntry e = getOrCreate(key);
        if (e.addUnsupported(protocol) && logger.isLoggable(Level.FINE)) {
            logger.fine(MessageFormat.format("[updated] {0} does not support {1}", key, protocol));
        }
    }

    /**
     * Clears the cache.
     */
    public static void clear() {
        final int size;
        final long stamp = lock.writeLock();
        try {
            size = cache.size();
            cache.clear();
        } finally {
            lock.unlockWrite(stamp);
        }
        if (size != 0 && logger.isLoggable(Level.FINE)) {
            if (size != 1) {
                logger.fine(MessageFormat.format("[cleared] {0} entries", size));
            } else {
                logger.fine("[cleared] 1 entry");
            }
        }
    }

    /**
     * Returns the entry for {@code key}, inserting a new empty entry if none exists yet.
     * Lookup and insertion happen atomically under the exclusive lock.
     */
    private static CacheEntry getOrCreate(String key) {
        final long stamp = lock.writeLock();
        try {
            return cache.computeIfAbsent(key, CacheEntry::new);
        } finally {
            lock.unlockWrite(stamp);
        }
    }

    /**
     * Builds the cache key {@code "<host string>:<port>"} from an {@link InetSocketAddress}.
     *
     * @throws NullPointerException if {@code remoteAddress} is {@code null}
     * @throws IllegalArgumentException if {@code remoteAddress} is not an {@link InetSocketAddress}
     */
    private static String key(SocketAddress remoteAddress) {
        requireNonNull(remoteAddress, "remoteAddress");
        if (!(remoteAddress instanceof InetSocketAddress)) {
            throw new IllegalArgumentException("remoteAddress: " + remoteAddress +
                            " (expected: an " + InetSocketAddress.class.getSimpleName() + ')');
        }
        final InetSocketAddress raddr = (InetSocketAddress) remoteAddress;
        final String host = raddr.getHostString();
        return host + ':' + raddr.getPort();
    }

    /**
     * The set of protocols known to be unsupported by one remote endpoint.
     *
     * <p>The set is copy-on-write: readers see an immutable-in-practice snapshot through the volatile
     * field without locking, while writers are serialized by {@code synchronized} so that concurrent
     * additions of different protocols are never lost.
     */
    private static final class CacheEntry {
        private volatile EnumSet<SessionProtocol> unsupported = EnumSet.noneOf(SessionProtocol.class);

        CacheEntry(String key) {
            // Key is unused. It's just here to simplify the Map.computeIfAbsent() call in getOrCreate().
        }

        /**
         * Records {@code protocol} as unsupported.
         *
         * @return {@code true} if the protocol was newly added, {@code false} if it was already recorded
         */
        synchronized boolean addUnsupported(SessionProtocol protocol) {
            final EnumSet<SessionProtocol> current = this.unsupported;
            if (current.contains(protocol)) {
                return false;
            }
            final EnumSet<SessionProtocol> copy = EnumSet.copyOf(current);
            copy.add(protocol);
            // Publish the new set only after it is fully built.
            this.unsupported = copy;
            return true;
        }

        boolean isUnsupported(SessionProtocol protocol) {
            requireNonNull(protocol, "protocol");
            return unsupported.contains(protocol);
        }

        @Override
        public String toString() {
            return unsupported.toString();
        }
    }
}
