package host.anzo.core.service;

import host.anzo.commons.annotations.startup.Scheduled;
import host.anzo.commons.annotations.startup.StartupComponent;
import host.anzo.commons.model.enums.EFirewallType;
import host.anzo.commons.model.enums.ERestrictionType;
import host.anzo.commons.utils.DateTimeUtils;
import host.anzo.commons.utils.NetworkUtils;
import host.anzo.core.config.FirewallConfig;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.SystemUtils;
import org.jctools.maps.NonBlockingHashMap;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

import java.io.IOException;
import java.net.InetAddress;
import java.net.UnknownHostException;
import java.time.format.DateTimeFormatter;
import java.util.Map;
import java.util.concurrent.TimeUnit;

/**
 * @author ANZO
 */
@Slf4j(topic = "Firewall")
@StartupComponent("Service")
public class FirewallService {
    // Singleton instance; Lombok's @Getter(lazy = true) generates the accessor for it.
    @Getter(lazy = true)
    private static final FirewallService instance = new FirewallService();

    // Internally blocked addresses: IP -> unblock time in epoch milliseconds (filled by addBlock in INTERNAL mode).
    private final Map<String, Long> blockedIps = new NonBlockingHashMap<>();
    // Rate limiters: service class name -> (rate limit key, i.e. "ip" or "ip:sourcePort" -> limiter).
    private final Map<String, Map<String, BurstRateLimiter>> connectionRateLimiters = new NonBlockingHashMap<>();

    /**
     * Private constructor: the service is only instantiated through the static {@code instance} field.
     * Flushes the configured system firewall sets on creation (see {@link #flushSystemFirewall()}).
     */
    private FirewallService() {
        flushSystemFirewall();
    }

    /**
     * Convenience overload: uses the simple name of {@code clazz} as the service name and delegates to
     * {@link #isAllowedAddress(String, String, Integer, int, double, ERestrictionType)}.
     * @param clazz service class (its {@link Class#getSimpleName() simple name} identifies the service)
     * @param ip client IP address
     * @param sourcePort client service connection port (can be null)
     * @param destPort destination service port
     * @param allowedRequestsPerSecond maximum allowed requests per second
     * @param restrictionType restriction, applied to connection exceeded rate limit
     * @return {@code true} if specified IP allowed to connect/request specified service clazz, {@code false} otherwise
     */
    public boolean isAllowedAddress(@NotNull Class<?> clazz, String ip, @Nullable Integer sourcePort, int destPort, double allowedRequestsPerSecond, ERestrictionType restrictionType) {
        return isAllowedAddress(clazz.getSimpleName(), ip, sourcePort, destPort, allowedRequestsPerSecond, restrictionType);
    }

    /**
     * Checks whether a client may connect to / send a request to the given service.
     * <p>
     * Addresses reported as local by {@code NetworkUtils.isLocalAddress} are always allowed. Addresses present
     * in the internal block list are always rejected. Every other request consumes one permit from a per-service,
     * per-client {@link BurstRateLimiter}; when no permit is available the request is rejected, the limiter for
     * that client is discarded and, for {@link ERestrictionType#BAN}, the address is blocked through
     * {@link #addBlock(String, String, Integer, int, long, TimeUnit)}.
     * @param className service class name
     * @param ip client IP address
     * @param sourcePort client service connection port (can be null); when present it is part of the rate limit
     *                   key, so each client connection is limited separately
     * @param destPort destination service port
     * @param allowedRequestsPerSecond maximum allowed requests per second (also used as the limiter burst size)
     * @param restrictionType restriction, applied to connection exceeded rate limit
     * @return {@code true} if specified IP allowed to connect/request specified service clazz, {@code false} otherwise
     */
    public boolean isAllowedAddress(String className, String ip, @Nullable Integer sourcePort, int destPort, double allowedRequestsPerSecond, ERestrictionType restrictionType) {
        try {
            if (NetworkUtils.isLocalAddress(InetAddress.getByName(ip))) {
                return true;
            }
        }
        catch (UnknownHostException ignored) {
            // The address could not be resolved: it cannot be classified as local,
            // so fall through and apply the regular block list / rate limit checks.
        }

        if (blockedIps.containsKey(ip)) {
            return false;
        }

        // Add source port to IP to separate client connections
        final String rateLimitKey = sourcePort == null ? ip : ip + ":" + sourcePort;

        final Map<String, BurstRateLimiter> classLimiters = connectionRateLimiters.computeIfAbsent(className, k -> new NonBlockingHashMap<>());
        final BurstRateLimiter ipLimiter = classLimiters.computeIfAbsent(rateLimitKey, k -> new BurstRateLimiter(allowedRequestsPerSecond, allowedRequestsPerSecond));
        // Check connection per second rule
        if (ipLimiter.tryAcquire()) {
            return true;
        }

        if (restrictionType == ERestrictionType.BAN) {
            addBlock(className, ip, sourcePort, destPort, FirewallConfig.FIREWALL_BAN_TIME, TimeUnit.MILLISECONDS);
        }
        // Drop the limiter under the same key it was registered with; removing by the bare IP
        // would leave the "ip:sourcePort" entry in place.
        classLimiters.remove(rateLimitKey);
        return false;
    }

    /**
     * Add firewall block rule for specified parameters.
     * <p>
     * With the {@code SYSTEM} firewall type on Linux, the configured rule command is executed with {@code $ip}
     * replaced by the address, and the block is logged as permanent. With the {@code INTERNAL} firewall type the
     * address is stored in the internal block list together with its unblock time. In any other configuration
     * (or when the system command fails to start) nothing is blocked and no block message is logged.
     * @param className service class name
     * @param ip client IP address
     * @param sourcePort client service connection port (can be null, using for logs)
     * @param destPort destination service port
     * @param banTime ban time in specified time unit's (only used by the internal firewall)
     * @param banTimeUnit ban time unit's
     */
    public void addBlock(String className, String ip, @Nullable Integer sourcePort, int destPort, long banTime, TimeUnit banTimeUnit) {
        // 0 = nothing was blocked, -1 = blocked by the system firewall (no expiry tracked here),
        // > 0 = epoch millis at which the internal block expires
        long unbanTime = 0;
        if (FirewallConfig.FIREWALL_TYPE == EFirewallType.SYSTEM && SystemUtils.IS_OS_LINUX) {
            // Tokenize the rule template first and substitute the address afterwards, so the address
            // always ends up inside a single argument and can never add arguments of its own.
            final String[] firewallCommand = FirewallConfig.FIREWALL_SYSTEM_FIREWALL_RULE.split(" ");
            for (int i = 0; i < firewallCommand.length; i++) {
                firewallCommand[i] = firewallCommand[i].replace("$ip", ip);
            }
            try {
                Runtime.getRuntime().exec(firewallCommand);
                unbanTime = -1;
            } catch (IOException e) {
                log.error("Error while adding firewall rule for class=[{}] and ipAddress=[{}] (sourcePort=[{}])", className, ip, sourcePort == null ? "N/A" : sourcePort, e);
            }
        }
        else if (FirewallConfig.FIREWALL_TYPE == EFirewallType.INTERNAL) {
            // Block address internally for a specified time
            unbanTime = System.currentTimeMillis() + TimeUnit.MILLISECONDS.convert(banTime, banTimeUnit);
            blockedIps.put(ip, unbanTime);
        }

        if (unbanTime != 0) {
            final String sourcePortInfo = sourcePort == null ? "" : " (sourcePort=" + sourcePort + ")";
            if (unbanTime > 0) {
                log.warn("Address ip=[{}]{} blocked by [{}] firewall at port [{}] until [{}]", ip, sourcePortInfo, className, destPort, DateTimeUtils.getLocalDateTime(unbanTime).format(DateTimeFormatter.ISO_LOCAL_DATE_TIME));
            } else {
                log.warn("Address ip=[{}]{} blocked by [{}] firewall at port [{}] permanently", ip, sourcePortInfo, className, destPort);
            }
        }
    }

    /**
     * Remove specified IP address from internal firewall.
     * Only the internal block list is modified; this method does not run any system firewall command,
     * so a rule installed by the system firewall is left untouched.
     * @param ipAddress IP address to remove
     */
    public void removeBlock(String ipAddress) {
        blockedIps.remove(ipAddress);
    }

    /**
     * Clear blocked IP list.
     * Removes every entry from the internal block list; the rate limiters and the system firewall
     * are not touched by this method.
     */
    public void clear() {
        blockedIps.clear();
    }

    /**
     * Flushes every configured nftables set ({@code nft flush set inet filter <set>}).
     * Does nothing unless the firewall type is {@code SYSTEM} and the OS is Linux. A command that fails
     * to start is logged and the remaining sets are still processed.
     */
    public void flushSystemFirewall() {
        final boolean systemFirewallActive = FirewallConfig.FIREWALL_TYPE == EFirewallType.SYSTEM && SystemUtils.IS_OS_LINUX;
        if (!systemFirewallActive) {
            return;
        }

        for (String setName : FirewallConfig.FIREWALL_FLUSHED_SETS_BEFORE_START) {
            final String[] commandLine = ("nft flush set inet filter " + setName).split(" ");
            try {
                Runtime.getRuntime().exec(commandLine);
            } catch (IOException e) {
                log.error("Error while flushing firewall set [{}]", setName, e);
            }
        }
    }

    /**
     * Scheduled maintenance task: removes expired entries from the internal block list and
     * discards all rate limiters (they are recreated on demand by the next request).
     */
    @SuppressWarnings("unused")
    @Scheduled(period = 1, timeUnit = TimeUnit.MINUTES, runAfterServerStart = true)
    public void cleanupBans() {
        try {
            // Read the clock once so every entry is judged against the same instant
            final long now = System.currentTimeMillis();
            for (Map.Entry<String, Long> entry : blockedIps.entrySet()) {
                final Long unbanTime = entry.getValue();
                if (unbanTime < now) {
                    // Conditional remove: if the address was blocked again in the meantime
                    // (new unblock time), the fresh entry must survive this cleanup pass.
                    blockedIps.remove(entry.getKey(), unbanTime);
                }
            }
        }
        catch (Exception e) {
            log.error("Error while cleaning up firewall bans", e);
        }
        finally {
            // Always reset the limiters, even when the ban cleanup failed
            connectionRateLimiters.clear();
        }
    }

    /**
     * Implementation of a rate limiter with a burst (i.e. a limiter that allows some amount of bursty traffic).
     * The burst is replenished at a constant rate, measured with the monotonic {@link System#nanoTime()} clock
     * so that wall-clock adjustments cannot add or remove permits.
     * All public methods are {@code synchronized}, so one instance may be shared between threads.
     * @author ANZO
     */
    public static class BurstRateLimiter {
        private static final double NANOS_PER_SECOND = 1_000_000_000.0;

        private final double maxBurst;        // Maximum burst (can be 1.5, 14.5, etc.)
        private final double permitsPerNano;  // Rate at which the burst is replenished (permits/ns)
        private double availablePermits;      // Current amount of available permits (fractional)
        private long lastUpdateNanos;         // System.nanoTime() value of the last refill

        /**
         * Creates a new {@link BurstRateLimiter} with the given rate and burst.
         * The limiter starts with a full burst available.
         * @param permitsPerSecond the rate of the limiter (permits/second), must be positive and finite
         * @param maxBurst the maximum burst of the limiter, must be positive
         * @throws IllegalArgumentException if an argument is not positive (this includes NaN)
         *                                  or the rate is infinite
         */
        public BurstRateLimiter(double permitsPerSecond, double maxBurst) {
            // Negated "> 0" instead of "<= 0": every comparison with NaN is false,
            // so this form rejects NaN as well.
            if (!(permitsPerSecond > 0) || !(maxBurst > 0)) {
                throw new IllegalArgumentException("Rate and burst must be positive");
            }
            // An infinite rate would turn the refill into 0 * Infinity = NaN
            if (Double.isInfinite(permitsPerSecond)) {
                throw new IllegalArgumentException("Rate must be finite");
            }
            this.maxBurst = maxBurst;
            this.permitsPerNano = permitsPerSecond / NANOS_PER_SECOND;
            this.availablePermits = maxBurst;
            this.lastUpdateNanos = System.nanoTime();
        }

        /**
         * Tries to acquire 1 permit.
         * @return {@code true} if the permits are acquired, {@code false} otherwise
         */
        public synchronized boolean tryAcquire() {
            return tryAcquire(1.0);
        }

        /**
         * Tries to acquire the given number of permits (fractional). Never blocks.
         * @param permits the number of permits to acquire
         * @return {@code true} if the permits are acquired, {@code false} otherwise
         * @throws IllegalArgumentException if {@code permits} is not positive (this includes NaN)
         */
        public synchronized boolean tryAcquire(double permits) {
            if (!(permits > 0)) {
                throw new IllegalArgumentException("Permits must be positive");
            }

            // Refill the permits before checking
            refillPermits();

            if (availablePermits >= permits) {
                availablePermits -= permits;
                return true;
            }
            return false;
        }

        /**
         * Refills the available permits by calculating the amount of permits that
         * have been replenished since the last update, capped at the maximum burst.
         * Must be called while holding the instance lock.
         */
        private void refillPermits() {
            final long now = System.nanoTime();
            // nanoTime values are only meaningful as differences; subtracting is overflow-safe
            final long elapsedNanos = now - lastUpdateNanos;
            lastUpdateNanos = now;

            // Defensive: never let a non-positive interval take permits away
            if (elapsedNanos > 0) {
                availablePermits = Math.min(
                        maxBurst,
                        availablePermits + elapsedNanos * permitsPerNano
                );
            }
        }
    }
}