package top.cenze.utils.http.request;

import cn.hutool.core.collection.CollectionUtil;
import cn.hutool.core.convert.Convert;
import cn.hutool.core.util.ObjectUtil;
import cn.hutool.core.util.StrUtil;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import io.netty.buffer.ByteBufAllocator;
import lombok.extern.slf4j.Slf4j;
import org.springframework.cloud.gateway.filter.GatewayFilterChain;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.core.io.buffer.NettyDataBufferFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpRequestDecorator;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.util.UriComponentsBuilder;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import java.io.UnsupportedEncodingException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.URI;
import java.net.URLEncoder;
import java.net.UnknownHostException;
import java.nio.CharBuffer;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;

/**
 * Static helpers for Spring Cloud Gateway filters: reading the token / request parameters,
 * rewriting request parameters before forwarding, and resolving the client IP and host name.
 *
 * <p>NOTE(review): reading a POST/PUT body through these helpers subscribes to (and thereby
 * consumes) {@link ServerHttpRequest#getBody()}. The body then has to be re-supplied to the
 * downstream chain, e.g. via {@link #rewriteRequestParams}; otherwise later consumers cannot
 * read it again.
 */
@Slf4j
public class GatewayUtil {
    // Separator between the IPs appended by successive reverse proxies
    private static final String IP_UTILS_FLAG = ",";
    // Placeholder some proxies write when the client IP is not known
    private static final String UNKNOWN = "unknown";
    // IPv6 loopback address as printed by Inet6Address#getHostAddress
    private static final String LOCALHOST_IP = "0:0:0:0:0:0:0:1";
    // IPv4 loopback address
    private static final String LOCALHOST_IP1 = "127.0.0.1";

    /**
     * Headers consulted, in this order, when resolving the client IP. In k8s the real client IP
     * is put into X-Original-Forwarded-For while the WAF origin address goes into X-Forwarded-For.
     * HttpHeaders lookups are case-insensitive, so each name is listed only once.
     *
     * <p>NOTE(review): all of these headers can be set by the caller and are only trustworthy if
     * a trusted proxy overwrites them — do not use the result for authorization without checking
     * the deployment.
     */
    private static final String[] CLIENT_IP_HEADERS = {
            "X-Original-Forwarded-For",
            "X-Forwarded-For",
            "Proxy-Client-IP",
            "WL-Proxy-Client-IP",
            "HTTP_CLIENT_IP",
            "HTTP_X_FORWARDED_FOR",
            "X-Real-IP"
    };

    // Shared factory for replacement request bodies (the factory itself holds no per-request state)
    private static final NettyDataBufferFactory BUFFER_FACTORY = new NettyDataBufferFactory(ByteBufAllocator.DEFAULT);

    /**
     * Returns the request token.
     *
     * <p>The {@code TOKEN} header is checked first; if it is missing or empty the {@code TOKEN}
     * request parameter is used (see {@link #getParameter}), which for POST/PUT reads the body.
     *
     * @param exchange current exchange
     * @return the token, or {@code null} if neither header nor parameter is present
     */
    public static String getToken(ServerWebExchange exchange) {
        String token = exchange.getRequest().getHeaders().getFirst("TOKEN");
        if (StrUtil.isEmpty(token)) {
            token = getParameter(exchange, "TOKEN");
        }
        return token;
    }

    /**
     * Returns a single request parameter.
     *
     * <p>GET: looked up in the query string. POST/PUT: the body is read and parsed as a JSON
     * object and the top-level field {@code name} is returned as a string. Bodies that are not a
     * JSON object (e.g. form data) yield {@code null} instead of an exception.
     *
     * @param exchange current exchange
     * @param name     parameter name
     * @return the parameter value, or {@code null} if absent or the method is not GET/POST/PUT
     */
    public static String getParameter(ServerWebExchange exchange, String name) {
        ServerHttpRequest request = exchange.getRequest();
        HttpMethod method = request.getMethod();

        if (HttpMethod.GET.equals(method)) {
            return request.getQueryParams().getFirst(name);
        }
        if (!HttpMethod.POST.equals(method) && !HttpMethod.PUT.equals(method)) {
            return null;
        }

        JSONObject json = parseJsonObject(resolveBodyFromRequest(request));
        if (json == null) {
            return null;
        }
        // debug rather than info: the body may carry credentials such as the token itself
        log.debug("getParameter json: {}", json.toJSONString());
        return Convert.toStr(json.get(name));
    }

    /**
     * Builds the rewritten parameters with one parameter added/replaced.
     * After calling this, pass the result to {@link #rewriteRequestParams} to rewrite the request.
     *
     * @param exchange current exchange
     * @param chain    filter chain (not used here; kept for API compatibility)
     * @param name     parameter name
     * @param value    parameter value
     * @return the new body / query string, or {@code null} if nothing should be rewritten
     */
    public static String setParameters(ServerWebExchange exchange, GatewayFilterChain chain, String name, Object value) {
        Map<String, Object> mapParams = new HashMap<>();
        mapParams.put(name, value);

        return setParameters(exchange, chain, mapParams);
    }

    /**
     * Builds the rewritten parameters with all entries of {@code mapParams} added/replaced.
     * After calling this, pass the result to {@link #rewriteRequestParams} to rewrite the request.
     *
     * <ul>
     *   <li>POST {@code application/x-www-form-urlencoded} (any parameters such as charset are
     *       accepted): the URL-encoded pairs are appended to the body.</li>
     *   <li>POST {@code application/json}: the entries are put into the top-level JSON object.</li>
     *   <li>GET: the URL-encoded pairs are appended to the raw query string.</li>
     * </ul>
     * Other POST body types (multipart uploads etc.) are left untouched, since adding parameters
     * to a file stream would corrupt it. Keys and values must be passed unencoded.
     *
     * @param exchange  current exchange
     * @param chain     filter chain (not used here; kept for API compatibility)
     * @param mapParams parameters to add
     * @return the new body / query string, or {@code null} if nothing should be rewritten
     */
    public static String setParameters(ServerWebExchange exchange, GatewayFilterChain chain, Map<String, Object> mapParams) {
        ServerHttpRequest request = exchange.getRequest();
        HttpMethod method = request.getMethod();
        String contentType = request.getHeaders().getFirst(HttpHeaders.CONTENT_TYPE);

        if (HttpMethod.POST.equals(method)) {
            boolean isForm = isMediaType(contentType, MediaType.APPLICATION_FORM_URLENCODED);
            boolean isJson = isMediaType(contentType, MediaType.APPLICATION_JSON);
            if (!isForm && !isJson) {
                return null;
            }
            // Nothing to add: return before reading (and thereby consuming) the body
            if (CollectionUtil.isEmpty(mapParams)) {
                return null;
            }
            String strBody = resolveBodyFromRequest(request);
            if (StrUtil.isEmpty(strBody)) {
                return null;
            }

            if (isForm) {
                return appendEncodedParams(strBody, mapParams);
            }

            JSONObject json = parseJsonObject(strBody);
            if (json == null) {
                return null;
            }
            json.putAll(mapParams);
            String result = json.toJSONString();
            log.debug("setParameters json: {}", result);
            return result;
        }

        if (HttpMethod.GET.equals(method)) {
            if (CollectionUtil.isEmpty(mapParams)) {
                return null;
            }
            return appendEncodedParams(request.getURI().getRawQuery(), mapParams);
        }

        return null;
    }

    /**
     * Forwards the exchange with the parameters produced by {@link #setParameters}.
     *
     * <p>POST form/JSON: {@code strBody} replaces the request body and Content-Length is set to
     * its UTF-8 byte length. GET: {@code strBody} replaces the raw query string; it must already
     * be URL-encoded, because {@code build(true)} rejects illegal characters with an
     * IllegalArgumentException. Anything else (or an empty {@code strBody}) is forwarded as is.
     *
     * @param exchange current exchange
     * @param chain    filter chain to continue
     * @param strBody  new body / query string
     * @return completion of the remaining filter chain
     */
    public static Mono<Void> rewriteRequestParams(ServerWebExchange exchange, GatewayFilterChain chain, String strBody) {
        if (StrUtil.isEmpty(strBody)) {
            return chain.filter(exchange);
        }

        ServerHttpRequest request = exchange.getRequest();
        HttpMethod method = request.getMethod();
        String contentType = request.getHeaders().getFirst(HttpHeaders.CONTENT_TYPE);

        if (HttpMethod.POST.equals(method)
                && (isMediaType(contentType, MediaType.APPLICATION_FORM_URLENCODED)
                        || isMediaType(contentType, MediaType.APPLICATION_JSON))) {
            // The original body was already consumed, so a decorated request with the new body is forwarded
            return chain.filter(exchange.mutate().request(withBody(request, strBody)).build());
        }

        if (HttpMethod.GET.equals(method)) {
            URI newUri = UriComponentsBuilder.fromUri(request.getURI())
                    .replaceQuery(strBody)
                    .build(true)
                    .toUri();
            ServerHttpRequest newRequest = request.mutate().uri(newUri).build();
            return chain.filter(exchange.mutate().request(newRequest).build());
        }

        return chain.filter(exchange);
    }

    /** Wraps {@code request} so that it exposes {@code body} (UTF-8) with a matching Content-Length. */
    private static ServerHttpRequest withBody(ServerHttpRequest request, String body) {
        // Byte length, not character length, and in the same charset used to encode the body
        long contentLength = body.getBytes(StandardCharsets.UTF_8).length;
        // A fresh buffer per subscription: a single pre-allocated buffer would already be
        // released on a second subscription and would leak if the body is never subscribed
        Flux<DataBuffer> bodyFlux = Flux.defer(() -> Flux.just(stringToDataBuffer(body)));

        return new ServerHttpRequestDecorator(request) {
            @Override
            public HttpHeaders getHeaders() {
                HttpHeaders httpHeaders = new HttpHeaders();
                httpHeaders.putAll(super.getHeaders());
                // RFC 7230 §3.3.2: Content-Length must not be sent together with Transfer-Encoding
                httpHeaders.remove(HttpHeaders.TRANSFER_ENCODING);
                httpHeaders.setContentLength(contentLength);
                return httpHeaders;
            }

            @Override
            public Flux<DataBuffer> getBody() {
                return bodyFlux;
            }
        };
    }

    /**
     * Returns {@code true} when {@code contentType} has the same type/subtype as {@code expected},
     * ignoring parameters such as {@code charset}; malformed or missing values give {@code false}.
     */
    private static boolean isMediaType(String contentType, MediaType expected) {
        if (StrUtil.isBlank(contentType)) {
            return false;
        }
        try {
            MediaType actual = MediaType.parseMediaType(contentType);
            return expected.getType().equalsIgnoreCase(actual.getType())
                    && expected.getSubtype().equalsIgnoreCase(actual.getSubtype());
        } catch (IllegalArgumentException e) {
            // Spring's InvalidMediaTypeException extends IllegalArgumentException
            return false;
        }
    }

    /**
     * Parses {@code body} as a JSON object. Empty input gives an empty object; input that is not
     * a JSON object (malformed, an array, form data, literal {@code null}) gives {@code null}.
     */
    private static JSONObject parseJsonObject(String body) {
        if (StrUtil.isEmpty(body)) {
            return new JSONObject();
        }
        try {
            return JSON.parseObject(body);
        } catch (RuntimeException e) {
            // fastjson reports malformed / non-object input through an unchecked exception
            log.warn("request body is not a JSON object: {}", e.getMessage());
            return null;
        }
    }

    /**
     * Appends URL-encoded {@code key=value} pairs to an already encoded query/form string,
     * inserting exactly one '&' between segments. {@code null} keys/values become "null".
     */
    private static String appendEncodedParams(String original, Map<String, Object> params) {
        StringBuilder sb = new StringBuilder();
        if (StrUtil.isNotEmpty(original)) {
            sb.append(original);
        }
        for (Map.Entry<String, Object> entry : params.entrySet()) {
            if (sb.length() > 0 && sb.charAt(sb.length() - 1) != '&') {
                sb.append('&');
            }
            sb.append(urlEncode(String.valueOf(entry.getKey())))
                    .append('=')
                    .append(urlEncode(String.valueOf(entry.getValue())));
        }
        return sb.toString();
    }

    /** URL-encodes {@code value} as UTF-8 (application/x-www-form-urlencoded rules). */
    private static String urlEncode(String value) {
        try {
            return URLEncoder.encode(value, StandardCharsets.UTF_8.name());
        } catch (UnsupportedEncodingException e) {
            // UTF-8 support is mandatory on every JVM, so this cannot happen
            throw new IllegalStateException(e);
        }
    }

    /**
     * Reads the whole request body as a UTF-8 string.
     *
     * <p>All chunks are joined before decoding, so multi-chunk bodies and multi-byte characters
     * split across chunks are handled. Because the value is read right after {@code subscribe()},
     * only a body emitted synchronously during subscription is returned (presumably one cached by
     * an earlier filter — verify); otherwise, or for an empty body, {@code null} is returned.
     * Subscribing consumes the body.
     *
     * @return the request body, or {@code null}
     */
    private static String resolveBodyFromRequest(ServerHttpRequest request) {
        AtomicReference<String> bodyRef = new AtomicReference<>();
        DataBufferUtils.join(request.getBody()).subscribe(
                buffer -> {
                    try {
                        bodyRef.set(StandardCharsets.UTF_8.decode(buffer.asByteBuffer()).toString());
                    } finally {
                        DataBufferUtils.release(buffer);
                    }
                },
                error -> log.error("resolveBodyFromRequest error: ", error));
        return bodyRef.get();
    }

    /**
     * Encodes a string as UTF-8 into a newly allocated DataBuffer.
     *
     * @param value text to encode
     * @return buffer holding the UTF-8 bytes of {@code value}
     */
    private static DataBuffer stringToDataBuffer(String value) {
        byte[] bytes = value.getBytes(StandardCharsets.UTF_8);
        DataBuffer buffer = BUFFER_FACTORY.allocateBuffer(bytes.length);
        buffer.write(bytes);
        return buffer;
    }

    /**
     * Returns the caller's IP address.
     *
     * @param exchange current exchange
     * @return client IP, or {@code null} if it cannot be determined
     */
    public static String getRemoteIP(ServerWebExchange exchange) {
        return getRemoteIP(exchange.getRequest());
    }

    /**
     * Returns the caller's IP address: the first non-empty, non-"unknown" value among
     * {@link #CLIENT_IP_HEADERS}, otherwise the socket peer address (a loopback peer is replaced
     * by this host's own address). When a header lists several IPs, the first one is returned.
     *
     * @param request current request
     * @return client IP, or {@code null} if it cannot be determined
     */
    public static String getRemoteIP(ServerHttpRequest request) {
        String ip = null;
        try {
            HttpHeaders headers = request.getHeaders();
            for (String headerName : CLIENT_IP_HEADERS) {
                ip = headers.getFirst(headerName);
                if (isKnownIp(ip)) {
                    break;
                }
            }
            // Fall back to the socket peer (e.g. direct access inside a k8s cluster)
            if (!isKnownIp(ip)) {
                ip = resolveSocketIp(request);
            }
        } catch (Exception e) {
            // best effort: never fail a request because the client IP could not be determined
            log.error("IPUtils ERROR ", e);
        }

        // Behind several proxies the value is "client, proxy1, proxy2": keep the first entry
        if (StrUtil.isNotEmpty(ip) && ip.indexOf(IP_UTILS_FLAG) > 0) {
            ip = ip.substring(0, ip.indexOf(IP_UTILS_FLAG)).trim();
        }
        return ip;
    }

    /** Returns {@code true} when {@code ip} is non-empty and not the "unknown" placeholder. */
    private static boolean isKnownIp(String ip) {
        return StrUtil.isNotEmpty(ip) && !UNKNOWN.equalsIgnoreCase(ip);
    }

    /**
     * Returns the socket peer address of {@code request}, or {@code null} if it is unavailable.
     * A loopback peer is replaced by the address of this host (InetAddress#getLocalHost);
     * if that lookup fails the loopback address is kept.
     */
    private static String resolveSocketIp(ServerHttpRequest request) {
        InetSocketAddress remoteAddress = request.getRemoteAddress();
        if (remoteAddress == null || remoteAddress.getAddress() == null) {
            return null;
        }
        String ip = remoteAddress.getAddress().getHostAddress();
        if (LOCALHOST_IP1.equalsIgnoreCase(ip) || LOCALHOST_IP.equalsIgnoreCase(ip)) {
            try {
                ip = InetAddress.getLocalHost().getHostAddress();
            } catch (UnknownHostException e) {
                log.error("getClientIp error: ", e);
            }
        }
        return ip;
    }

    /**
     * Returns the host name the request was addressed to.
     *
     * @param exchange current exchange
     * @return host part of the request URI (may be {@code null})
     */
    public static String getServerName(ServerWebExchange exchange) {
        return getServerName(exchange.getRequest());
    }

    /**
     * Returns the host name the request was addressed to.
     *
     * @param request current request
     * @return host part of the request URI (may be {@code null})
     */
    public static String getServerName(ServerHttpRequest request) {
        return request.getURI().getHost();
    }
}
