package top.doudou.base.web.filter.xxs;

import lombok.extern.slf4j.Slf4j;
import top.doudou.core.exception.ExceptionUtils;
import top.doudou.core.exception.XssException;

import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;

import static top.doudou.core.util.StrUtils.cleanXSS;

/**
 * Request wrapper that guards against cross-site scripting (XSS).
 *
 * <p>Parameter values, String attributes, headers and the (non-multipart) request body are
 * passed through {@code cleanXSS}. Parameter values, String attributes and the body are first
 * checked with {@code ValidateUtil.isContainsDefaultXSSForbiddenCharacter}; when that check
 * reports a forbidden character an {@link XssException} is thrown. Multipart (upload) requests
 * bypass body filtering.
 *
 * <p>NOTE(review): {@code getParameterMap()} and {@code getHeaders(String)} are not overridden
 * and therefore still return unfiltered values.
 *
 * @Description Cross-site scripting (XSS) protection
 * @Author 傻男人 <244191347@qq.com>
 * @Date 2020-09-28 16:02
 * @Version V1.0
 */
@Slf4j
public class XssHttpServletRequestWrapper extends HttpServletRequestWrapper {

    /** Prefix identifying multipart (upload) content types; compared case-insensitively. */
    private static final String MULTIPART_PREFIX = "multipart";

    // True when the request is a multipart upload; upload bodies are passed through unfiltered.
    boolean isUpData = false;

    /**
     * Wraps the given request and records whether it is a multipart upload.
     *
     * @param servletRequest the request to wrap
     */
    public XssHttpServletRequestWrapper(HttpServletRequest servletRequest) {
        super(servletRequest);
        String contentType = servletRequest.getContentType();
        // Media types are case-insensitive, so "Multipart/form-data" must also count as an upload;
        // otherwise its binary body would be decoded as text and corrupted.
        isUpData = contentType != null
                && contentType.regionMatches(true, 0, MULTIPART_PREFIX, 0, MULTIPART_PREFIX.length());
    }

    /**
     * Returns the cleaned values of a request parameter.
     *
     * @param parameter the parameter name
     * @return the cleaned values, or {@code null} if the parameter is absent
     * @throws XssException if any value contains a forbidden character
     */
    @Override
    public String[] getParameterValues(String parameter) {
        String[] values = super.getParameterValues(parameter);
        if (values == null) {
            return null;
        }
        String[] encodedValues = new String[values.length];
        for (int i = 0; i < values.length; i++) {
            if (ValidateUtil.isContainsDefaultXSSForbiddenCharacter(values[i])) {
                throw new XssException("Contains illegal characters[From getParameterValues method]:" + values[i]);
            }
            encodedValues[i] = cleanXSS(values[i]);
        }
        return encodedValues;
    }

    /**
     * Returns the cleaned value of a request parameter.
     *
     * @param parameter the parameter name
     * @return the cleaned value, or {@code null} if the parameter is absent
     * @throws XssException if the value contains a forbidden character
     */
    @Override
    public String getParameter(String parameter) {
        String value = super.getParameter(parameter);
        // Check for absence first so the validator is never handed a null.
        if (value == null) {
            return null;
        }
        if (ValidateUtil.isContainsDefaultXSSForbiddenCharacter(value)) {
            throw new XssException("Contains illegal characters[From getParameter method]：" + value);
        }
        return cleanXSS(value);
    }

    /**
     * Applies XSS filtering when reading request attributes. Non-String attributes (and
     * {@code null}) are returned unchanged.
     *
     * @param name the attribute name
     * @return the attribute, cleaned if it is a String
     * @throws XssException if a String attribute contains a forbidden character
     */
    @Override
    public Object getAttribute(String name) {
        Object value = super.getAttribute(name);
        if (!(value instanceof String)) {
            return value;
        }
        String text = (String) value;
        if (ValidateUtil.isContainsDefaultXSSForbiddenCharacter(text)) {
            throw new XssException("From getAttribute->参数包含非法字符：{}", text);
        }
        return cleanXSS(text);
    }

    /**
     * Returns the cleaned value of a request header. Unlike parameters, headers are not
     * rejected for forbidden characters; they are only cleaned.
     *
     * @param name the header name
     * @return the cleaned header value, or {@code null} if absent
     */
    @Override
    public String getHeader(String name) {
        String value = super.getHeader(name);
        if (value == null) {
            return null;
        }
        return cleanXSS(value);
    }

    /**
     * Returns the request body. For non-multipart requests the body is read fully, validated
     * and cleaned (see {@link #inputHandlers(ServletInputStream)}), and served from memory as
     * UTF-8 bytes.
     *
     * @return the (filtered) body stream
     * @throws IOException if the underlying stream cannot be obtained
     * @throws XssException if the body contains a forbidden character
     */
    @Override
    public ServletInputStream getInputStream() throws IOException {
        if (isUpData) {
            return super.getInputStream();
        }
        byte[] body = inputHandlers(super.getInputStream()).getBytes(StandardCharsets.UTF_8);
        return toServletInputStream(body);
    }

    /**
     * Returns a reader over the filtered body. Without this override the inherited
     * implementation returns the wrapped request's raw reader and bypasses XSS filtering.
     *
     * @return a UTF-8 reader over the (filtered) body
     * @throws IOException if the underlying body cannot be obtained
     * @throws XssException if the body contains a forbidden character
     */
    @Override
    public BufferedReader getReader() throws IOException {
        if (isUpData) {
            return super.getReader();
        }
        // Decode with the same charset getInputStream() used to encode the filtered body.
        return new BufferedReader(new InputStreamReader(getInputStream(), StandardCharsets.UTF_8));
    }

    /**
     * Reads the whole stream as UTF-8 text, validates it and returns the cleaned text.
     *
     * <p>Lines are concatenated without their line terminators. The stream is closed
     * afterwards. If reading fails, the error is logged and whatever was read so far is used.
     *
     * @param servletInputStream the body stream to consume
     * @return the cleaned body text
     * @throws XssException if the body contains a forbidden character
     */
    public String inputHandlers(ServletInputStream servletInputStream) {
        StringBuilder sb = new StringBuilder();
        // Closing the reader also closes the wrapped servlet stream.
        try (BufferedReader reader = new BufferedReader(
                new InputStreamReader(servletInputStream, StandardCharsets.UTF_8))) {
            String line;
            while ((line = reader.readLine()) != null) {
                sb.append(line);
            }
        } catch (IOException e) {
            log.error(ExceptionUtils.toString(e));
        }
        String body = sb.toString();
        if (ValidateUtil.isContainsDefaultXSSForbiddenCharacter(body)) {
            throw new XssException("Contains illegal characters[From getInputStream method]：" + body);
        }
        return cleanXSS(body);
    }

    /**
     * Exposes an in-memory byte array as a {@link ServletInputStream}. Because all data is
     * already buffered, the stream is always ready and is finished once every byte is consumed.
     */
    private static ServletInputStream toServletInputStream(byte[] body) {
        final ByteArrayInputStream source = new ByteArrayInputStream(body);
        return new ServletInputStream() {
            @Override
            public boolean isFinished() {
                return source.available() == 0;
            }

            @Override
            public boolean isReady() {
                // Reads from the buffered body never block.
                return true;
            }

            @Override
            public void setReadListener(ReadListener readListener) {
                // Intentionally a no-op (unchanged from the original): non-blocking listeners
                // are not notified for this buffered stream.
            }

            @Override
            public int read() {
                return source.read();
            }

            @Override
            public int read(byte[] b, int off, int len) {
                // Bulk read avoids the byte-at-a-time default implementation.
                return source.read(b, off, len);
            }

            @Override
            public int available() {
                return source.available();
            }
        };
    }

}