package cn.tworice.auth.service.impl;

import cn.tworice.auth.config.AuthProperties;
import cn.tworice.auth.service.AuthManager;
import cn.tworice.common.framework.mail.core.MailExecutor;
import cn.tworice.common.util.StringUtils;
import cn.tworice.common.vo.RequestResult;
import cn.tworice.common.vo.StateCodeConst;
import com.alibaba.fastjson.JSON;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import javax.annotation.Resource;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;

/**
 * 身份验证相关服务
 **/
@Service
@Slf4j
public class DefaultAuthManager implements AuthManager {

    @Resource
    private AuthProperties authProperties;

    @Resource
    private MailExecutor mailExecutor;

    // 存储各IP的失败登录时间戳（线程安全）
    private final ConcurrentHashMap<String, CopyOnWriteArrayList<Long>> ipFailureRecords = new ConcurrentHashMap<>();
    // 封禁IP集合（线程安全）
    private final ConcurrentHashMap.KeySetView<String, Boolean> bannedIPs = ConcurrentHashMap.newKeySet();

    @Override
    public boolean auth(HttpServletRequest request, HttpServletResponse response) {
        if (!authProperties.getBlast()) {
            return true;
        }

        String clientIP = getClientIP(request);

        // 1. 检查IP是否已被封禁
        if (bannedIPs.contains(clientIP)) {
            returnJson(response, "该IP已被封禁，禁止登录");
            return false;
        }

        // 2. 清理该IP的过期失败记录
        cleanExpiredRecords(clientIP);

        // 3. 检查失败次数是否达到阈值
        int failureCount = ipFailureRecords.getOrDefault(clientIP, new CopyOnWriteArrayList<>()).size();
        if (failureCount >= authProperties.getBlastCount()) {
            // 封禁IP并清空记录
            bannedIPs.add(clientIP);
            ipFailureRecords.remove(clientIP);
            // 发送报警邮件
            if(!StringUtils.isEmpty(authProperties.getMailBlast())){
                mailExecutor.sendMail(authProperties.getMailBlast(), "检测到恶意登录请求", "IP地址：" + clientIP + "，系统已自动封禁");
            }

            returnJson(response, "登录失败次数过多，IP已被封禁");
            return false;
        }

        // 4. 正常放行登录请求
        return true;
    }

    @Override
    public void record(HttpServletRequest request) {
        if (!authProperties.getBlast()) {
            return;
        }

        String clientIP = getClientIP(request);
        long currentTime = System.currentTimeMillis();

        // 原子性更新失败记录
        ipFailureRecords.computeIfAbsent(clientIP, k -> new CopyOnWriteArrayList<>())
                .add(currentTime);
    }

    /**
     * 清理指定IP的过期（超过1分钟）失败记录
     */
    private void cleanExpiredRecords(String clientIP) {
        CopyOnWriteArrayList<Long> records = ipFailureRecords.get(clientIP);
        if (records == null) return;

        long threshold = System.currentTimeMillis() - 60_000;
        records.removeIf(timestamp -> timestamp < threshold);

        // 清理后若记录为空则移除条目
        if (records.isEmpty()) {
            ipFailureRecords.remove(clientIP);
        }
    }

    /**
     * 获取客户端真实IP（适配代理场景）
     */
    private String getClientIP(HttpServletRequest request) {
        String ip = request.getHeader("X-Forwarded-For");
        if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
            ip = request.getHeader("Proxy-Client-IP");
        }
        if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
            ip = request.getHeader("WL-Proxy-Client-IP");
        }
        if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
            ip = request.getRemoteAddr();
        }
        return ip.split(",")[0].trim();
    }

    /**
     * 返回JSON格式响应
     */
    private void returnJson(HttpServletResponse response, String message) {
        response.setCharacterEncoding("UTF-8");
        response.setContentType("application/json; charset=utf-8");
        RequestResult result = new RequestResult(StateCodeConst.LOGIN_ERROR, message);

        try (PrintWriter writer = response.getWriter()) {
            writer.print(JSON.toJSONString(result));
        } catch (IOException e) {
            // 使用日志框架记录异常
            System.err.println("响应输出异常：" + e.getMessage());
        }
    }
}
