package host.anzo.core.service;

import com.google.common.util.concurrent.RateLimiter;
import host.anzo.commons.annotations.startup.Scheduled;
import host.anzo.commons.model.enums.EFirewallType;
import host.anzo.commons.model.enums.ERestrictionType;
import host.anzo.commons.utils.DateTimeUtils;
import host.anzo.commons.utils.NetworkUtils;
import host.anzo.commons.utils.VMUtils;
import host.anzo.core.config.FirewallConfig;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.SystemUtils;

import java.io.IOException;
import java.net.InetAddress;
import java.net.UnknownHostException;
import java.time.format.DateTimeFormatter;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;

/**
 * @author ANZO
 */
@Slf4j(topic = "Firewall")
@SuppressWarnings("UnstableApiUsage")
public class FirewallService {
    @Getter(lazy = true)
    private static final FirewallService instance = new FirewallService();

    private final Map<String, Long> blockedIps = new ConcurrentHashMap<>();
    private final Map<String, Map<String, RateLimiter>> connectionRateLimiters = new ConcurrentHashMap<>();

    private FirewallService() {
        flushSystemFirewall();
    }

    /**
     * @param clazz service class
     * @param ip client IP address
     * @param destPort destination service port
     * @param allowedRequestsPerSecond maximum allowed requests per second
     * @param restrictionType restriction, applied to connection exceeded rate limit
     * @return {@code true} if specified IP allowed to connect/request specified service clazz, {@code false} otherwise
     */
    public boolean isAllowedAddress(Class<?> clazz, String ip, int destPort, double allowedRequestsPerSecond, ERestrictionType restrictionType) {
        return isAllowedAddress(clazz.getSimpleName(), ip, destPort, allowedRequestsPerSecond, restrictionType);
    }

    /**
     * @param className service class name
     * @param ip client IP address
     * @param destPort destination service port
     * @param allowedRequestsPerSecond maximum allowed requests per second
     * @param restrictionType restriction, applied to connection exceeded rate limit
     * @return {@code true} if specified IP allowed to connect/request specified service clazz, {@code false} otherwise
     */
    public boolean isAllowedAddress(String className, String ip, int destPort, double allowedRequestsPerSecond, ERestrictionType restrictionType) {
        if (!VMUtils.DEBUG) {
            try {
                if (NetworkUtils.isLocalAddress(InetAddress.getByName(ip))) {
                    return true;
                }
            }
            catch (UnknownHostException ignored) {
            }
        }

        if (blockedIps.containsKey(ip)) {
            return false;
        }

        final Map<String, RateLimiter> classLimiters = connectionRateLimiters.computeIfAbsent(className, k -> new ConcurrentHashMap<>());
        final RateLimiter ipLimiter = classLimiters.computeIfAbsent(ip, k -> RateLimiter.create(allowedRequestsPerSecond));
        // Check connection per second rule
        if (!ipLimiter.tryAcquire()) {
            if (restrictionType == ERestrictionType.BAN) {
                addBlock(className, ip, destPort, FirewallConfig.FIREWALL_BAN_TIME, TimeUnit.MILLISECONDS);
            }
            classLimiters.remove(ip);
            return false;
        }
        return true;
    }

    /**
     * Add firewall block rule for specified parameters
     * @param className service class name
     * @param ip client IP address
     * @param destPort destination service port
     * @param banTime ban time in specified time unit's
     * @param banTimeUnit ban time unit's
     */
    public void addBlock(String className, String ip, int destPort, long banTime, TimeUnit banTimeUnit) {
        long unban_time = 0;
        if (FirewallConfig.FIREWALL_TYPE == EFirewallType.SYSTEM && SystemUtils.IS_OS_LINUX) {
            final String firewallCommand = FirewallConfig.FIREWALL_SYSTEM_FIREWALL_RULE.replace("$ip", ip);
            try {
                Runtime.getRuntime().exec(firewallCommand.split(" "));
                unban_time = -1;
            } catch (IOException e) {
                log.error("Error while adding firewall rule for class=[{}] and ipAddress=[{}]", className, ip, e);
            }
        }
        else if (FirewallConfig.FIREWALL_TYPE == EFirewallType.INTERNAL) {
            // Block address internally for a specified time
            unban_time = System.currentTimeMillis() + TimeUnit.MILLISECONDS.convert(banTime, banTimeUnit);
            blockedIps.put(ip, unban_time);
        }

        if (unban_time != 0) {
            if (unban_time > 0) {
                log.error("Address ip=[{}] blocked by [{}] firewall at port [{}] for [{}]", ip, className, destPort, DateTimeUtils.getLocalDateTime(unban_time).format(DateTimeFormatter.ISO_LOCAL_DATE));
            } else {
                log.error("Address ip=[{}] blocked by [{}] firewall at port [{}] permanently", ip, className, destPort);
            }
        }
    }

    /**
     * Remove specified IP address from internal firewall
     * @param ipAddress IP address to remove
     */
    public void removeBlock(String ipAddress) {
        blockedIps.remove(ipAddress);
    }

    /**
     * Clear blocked IP list
     */
    public void clear() {
        blockedIps.clear();
    }

    public void flushSystemFirewall() {
        if (FirewallConfig.FIREWALL_TYPE == EFirewallType.SYSTEM && SystemUtils.IS_OS_LINUX) {
            for (String set : FirewallConfig.FIREWALL_FLUSHED_SETS_BEFORE_START) {
                final String firewallCommand = "nft flush set inet filter " + set;
                try {
                    Runtime.getRuntime().exec(firewallCommand.split(" "));
                } catch (IOException e) {
                    log.error("Error while flushing firewall set [{}]", set, e);
                }
            }
        }
    }

    @SuppressWarnings("unused")
    @Scheduled(period = 1, timeUnit = TimeUnit.MINUTES, runAfterServerStart = true)
    public void cleanupBans() {
        for (Map.Entry<String, Long> entry : blockedIps.entrySet()) {
            if (entry.getValue() < System.currentTimeMillis()) {
                blockedIps.remove(entry.getKey());
            }
        }
        connectionRateLimiters.clear();
    }
}