package cn.sinozg.applet.common.filter;

import cn.sinozg.applet.common.constant.BaseConstants;
import cn.sinozg.applet.common.constant.HeaderConstants;
import cn.sinozg.applet.common.core.model.AesRsaDecrypt;
import cn.sinozg.applet.common.properties.SignValue;
import cn.sinozg.applet.common.utils.CypherUtil;
import cn.sinozg.applet.common.utils.JsonUtil;
import cn.sinozg.applet.common.utils.WebUtil;
import jakarta.servlet.ServletInputStream;
import jakarta.servlet.ServletResponse;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletRequestWrapper;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Request wrapper whose body (input stream / reader) can be read any number of times.
 *
 * <p>The body is read once in the constructor and cached as bytes. When
 * {@code sign.isRsaEnable()} is true, the raw body is parsed as an {@link AesRsaDecrypt}
 * envelope. Its payload is decrypted with {@code CypherUtil.decryptJson}, and the envelope's
 * public key is exposed as the {@code HeaderConstants.X_PUB_KEY} header. Otherwise the body is
 * cached unchanged as UTF-8 bytes.
 *
 * @Author: xyb
 * @Description:
 * @Date: 2022-11-14 09:41 PM
 **/
public class WrapperRequest extends HttpServletRequestWrapper {

    private static final Logger log = LoggerFactory.getLogger(WrapperRequest.class);

    /** Cached (possibly decrypted) request body, served by every call to {@link #getInputStream()}. */
    private final byte[] body;

    private final SignValue sign;

    /**
     * Headers synthesized by this wrapper. They take precedence over the wrapped request's headers
     * with the same name. Case-insensitive because HTTP header names are case-insensitive.
     */
    private final Map<String, String> headerMap = new TreeMap<>(String.CASE_INSENSITIVE_ORDER);

    /**
     * Wraps the request, forces UTF-8 on request and response, and reads and caches the body.
     *
     * @param request the incoming request; its body is consumed here
     * @param response the response; only its character encoding is set
     * @param sign signing/encryption settings that decide whether the body is decrypted
     * @throws IOException if the body cannot be read, the encrypted envelope cannot be parsed,
     *     or decryption fails
     */
    public WrapperRequest(HttpServletRequest request, ServletResponse response, SignValue sign) throws IOException {
        super(request);
        this.sign = sign;
        // Must happen before the body is read so it is decoded as UTF-8.
        request.setCharacterEncoding(BaseConstants.UTF8);
        response.setCharacterEncoding(BaseConstants.UTF8);
        body = decrypt();
    }

    /**
     * Returns a reader over the cached body, decoded as UTF-8. That is the charset the body was
     * encoded with, so the platform default charset must not be used here.
     */
    @Override
    public BufferedReader getReader() throws IOException {
        return new BufferedReader(new InputStreamReader(getInputStream(), StandardCharsets.UTF_8));
    }

    /** Returns a fresh stream over the cached body on every call. */
    @Override
    public ServletInputStream getInputStream() {
        return new WrapperInputStream(body);
    }

    /** Returns the synthesized header value if one exists (even if null), else the original header. */
    @Override
    public String getHeader(String name) {
        if (hasOverride(name)) {
            return headerMap.get(name);
        }
        return super.getHeader(name);
    }

    /**
     * Returns the original header names plus any synthesized ones. A synthesized name is skipped
     * if the original request already has it (case-insensitively), so no name appears twice.
     */
    @Override
    public Enumeration<String> getHeaderNames() {
        Enumeration<String> original = super.getHeaderNames();
        // The servlet spec allows null when the container hides headers.
        List<String> names = original == null ? new ArrayList<>() : Collections.list(original);
        for (String extra : headerMap.keySet()) {
            boolean alreadyListed = names.stream().anyMatch(extra::equalsIgnoreCase);
            if (!alreadyListed) {
                names.add(extra);
            }
        }
        return Collections.enumeration(names);
    }

    /**
     * Returns all values of a header. A synthesized header replaces, rather than adds to, any
     * values the client sent. This keeps the result consistent with {@link #getHeader(String)} and
     * stops a client-supplied value from appearing next to the server-derived one.
     */
    @Override
    public Enumeration<String> getHeaders(String name) {
        if (hasOverride(name)) {
            return Collections.enumeration(Collections.singletonList(headerMap.get(name)));
        }
        return super.getHeaders(name);
    }

    /** True when this wrapper synthesizes a header with the given name (null-safe). */
    private boolean hasOverride(String name) {
        // A case-insensitive TreeMap throws on a null key, so guard explicitly.
        return name != null && headerMap.containsKey(name);
    }

    /**
     * Reads the request body and decrypts it when RSA/AES encryption is enabled.
     *
     * @return the decrypted body bytes, or the raw body as UTF-8 bytes when encryption is disabled
     * @throws IOException if the encrypted envelope cannot be parsed or decryption yields no data;
     *     a decryption exception, if any, is attached as the cause
     */
    private byte[] decrypt() throws IOException {
        String jsonBody = WebUtil.getBodyString(this.getRequest());
        if (!sign.isRsaEnable()) {
            // NOTE(review): guards against WebUtil.getBodyString returning null; confirm its contract.
            return jsonBody == null ? new byte[0] : jsonBody.getBytes(StandardCharsets.UTF_8);
        }
        AesRsaDecrypt decrypt = JsonUtil.toPojo(jsonBody, AesRsaDecrypt.class);
        if (decrypt == null) {
            throw new IOException("Encryption parameter format error!");
        }
        headerMap.put(HeaderConstants.X_PUB_KEY, decrypt.getPublicKey());
        byte[] bytes = null;
        Exception cause = null;
        try {
            bytes = CypherUtil.decryptJson(decrypt.getData(), sign.getPrivateKey(), decrypt.getAesKey());
        } catch (Exception e) {
            log.error("解密参数错误！", e);
            cause = e;
        }
        if (bytes == null) {
            throw new IOException("Decryption parameter error!", cause);
        }
        return bytes;
    }
}
