package org.opoo.ootp.client.impl;

import cn.hutool.core.io.IoUtil;
import com.emc.codec.compression.CompressionConstants;
import com.emc.codec.encryption.EncryptionCodec;
import com.emc.codec.encryption.EncryptionConstants;
import com.emc.codec.encryption.EncryptionUtil;
import com.emc.codec.util.CodecUtil;
import lombok.extern.slf4j.Slf4j;
import org.apache.http.HttpEntity;
import org.apache.http.entity.ByteArrayEntity;
import org.apache.http.entity.ContentType;
import org.apache.http.entity.InputStreamEntity;
import org.opoo.ootp.client.ExsCodec;
import org.opoo.ootp.client.ExsMetadata;
import org.opoo.ootp.client.ExsMetadataConfigurerProvider;
import org.opoo.ootp.client.KeyProviderManager;
import org.opoo.ootp.client.Metadata;
import org.opoo.ootp.codec.Codec;
import org.opoo.ootp.codec.CodecDecoder;
import org.opoo.ootp.codec.CodecEncoder;
import org.opoo.ootp.codec.binary.BinaryEncodeConstants;
import org.opoo.ootp.codec.encryption.EncryptionException;
import org.opoo.ootp.codec.encryption.EncryptionUtils;
import org.opoo.ootp.codec.encryption.sm4.SM4EncryptionCodec;
import org.opoo.ootp.codec.encryption.sm4.SM4EncryptionConstants;
import org.opoo.ootp.codec.encryption.smx.SMXEncryptionCodec;
import org.opoo.ootp.codec.encryption.smx.SMXEncryptionConstants;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardCopyOption;
import java.security.DigestInputStream;
import java.security.DigestOutputStream;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Consumer;

@Slf4j
public class ExsCodecImpl implements ExsCodec {
    /** Default threshold (1 MB) above which encoded/decoded content is cached in a temp file instead of memory. */
    public static final long DEFAULT_DISK_CACHE_SIZE_THRESHOLD = 1048576L;

    /** Maximum number of characters of the original file name kept in a temp-file suffix. */
    private static final int MAX_TEMP_FILE_SUFFIX_LENGTH = 64;

    private final CodecEncoder encoder;
    private final CodecDecoder decoder;
    private final KeyProviderManager keyProviderManager;
    private final CodecMode mode;
    private long diskCacheSizeThreshold = DEFAULT_DISK_CACHE_SIZE_THRESHOLD;
    private ExsMetadataConfigurerProvider decodingExsMetadataConfigurerProvider;

    /**
     * Creates a codec used for encryption/decryption (encoding/decoding).
     *
     * @param keyProviderManager key material used while encrypting/decrypting (encoding/decoding)
     * @param mode               codec mode; decides which messages {@code encode} processes (see {@link #isRequired})
     * @param encodeSpecs        encoder spec descriptions, used for encryption (encoding)
     */
    public ExsCodecImpl(KeyProviderManager keyProviderManager, CodecMode mode, String... encodeSpecs) {
        this.keyProviderManager = keyProviderManager;
        this.encoder = new CodecEncoder(encodeSpecs);
        // There is only one (global) SM2 private key, so it is bound to the decoder once, here.
        this.decoder = new CodecDecoder().withProperty(SMXEncryptionCodec.PROP_PRIVATE_KEY_PROVIDER, keyProviderManager.getSM2PrivateKeyProvider());
        this.mode = mode;
    }

    /**
     * Creates a codec intended for decryption; the encoder is built from an empty spec list.
     * Note that the mode is {@link CodecMode#ALL}, so {@code encode} still runs (with no encode specs).
     *
     * @param keyProviderManager key material used while encrypting/decrypting (encoding/decoding)
     */
    public ExsCodecImpl(KeyProviderManager keyProviderManager) {
        this(keyProviderManager, CodecMode.ALL);
    }

    public CodecEncoder getEncoder() {
        return encoder;
    }

    public CodecDecoder getDecoder() {
        return decoder;
    }

    public KeyProviderManager getKeyProviderManager() {
        return keyProviderManager;
    }

    public CodecMode getMode() {
        return mode;
    }

    public long getDiskCacheSizeThreshold() {
        return diskCacheSizeThreshold;
    }

    public void setDiskCacheSizeThreshold(long diskCacheSizeThreshold) {
        this.diskCacheSizeThreshold = diskCacheSizeThreshold;
    }

    public ExsMetadataConfigurerProvider getDecodingExsMetadataConfigurerProvider() {
        return decodingExsMetadataConfigurerProvider;
    }

    public void setDecodingExsMetadataConfigurerProvider(ExsMetadataConfigurerProvider decodingExsMetadataConfigurerProvider) {
        this.decodingExsMetadataConfigurerProvider = decodingExsMetadataConfigurerProvider;
    }

    /**
     * Registers a configurer applied to decoding metadata of messages coming from {@code from}.
     * Lazily creates a {@link BasicExsMetadataConfigurerProvider} when no provider is set.
     *
     * @throws UnsupportedOperationException if a provider of another type has been set explicitly
     */
    public ExsCodecImpl addDecodingExsMetadataConfigurer(String from, Consumer<ExsMetadata> decodingMetadataConfigurer) {
        if (decodingExsMetadataConfigurerProvider == null) {
            decodingExsMetadataConfigurerProvider = new BasicExsMetadataConfigurerProvider();
        }

        if (decodingExsMetadataConfigurerProvider instanceof BasicExsMetadataConfigurerProvider) {
            ((BasicExsMetadataConfigurerProvider) decodingExsMetadataConfigurerProvider).addConfigurer(from, decodingMetadataConfigurer);
        } else {
            throw new UnsupportedOperationException("addDecodingExsMetadataConfigurer()");
        }

        return this;
    }

    /**
     * Builds the per-request encoding context (keys of the receiver {@code metadata.getTo()}).
     *
     * @throws NullPointerException if {@code metadata.getTo()} is null
     */
    protected Map<String,Object> buildEncodeContext(ExsMetadata metadata) {
        // For message sending, "to" is the message receiver.
        // For file upload: with "fs" storage "to" is the repo name; with default storage it is the receiver id;
        // with custom storage it is "storage_<custom storage>".
        final String to = metadata.getTo();
        Objects.requireNonNull(to, "加密发送消息或上传文件时，接收方的ID不能为空，可以通过设置 ExsMetadata 中的 to 属性来指定接收方的 ID");

        final Map<String,Object> context = new HashMap<>();
        Optional.ofNullable(keyProviderManager.getSM4KeyProvider(to)).ifPresent(o -> context.put(SM4EncryptionCodec.PROP_KEY_PROVIDER, o));
        Optional.ofNullable(keyProviderManager.getSM2PublicKey(to)).ifPresent(o -> context.put(SMXEncryptionCodec.PROP_PUBLIC_KEY, o));
        Optional.ofNullable(keyProviderManager.getRSAKeyProvider(to)).ifPresent(o -> context.put(EncryptionCodec.PROP_KEY_PROVIDER, o));
        // No private key needed: this only builds the encryption context.
        return context;
    }

    /**
     * Returns a fresh SM3 digest. The hash is computed while encoding so a later signature
     * step does not need to read the content again.
     *
     * @throws EncryptionException if the security provider does not offer SM3
     */
    protected MessageDigest getSM3Digest() {
        try {
            return MessageDigest.getInstance("SM3", EncryptionUtils.SECURITY_PROVIDER);
        } catch (NoSuchAlgorithmException e) {
            throw new EncryptionException("获取哈希算法失败", e);
        }
    }

    /**
     * Creates a temp file for disk caching when the metadata's content length is known and exceeds
     * {@link #getDiskCacheSizeThreshold()}.
     *
     * @return the created temp file, or {@code null} when content should be kept in memory
     *         (including when the content length is unknown)
     */
    protected Path createTempFile(ExsMetadata metadata) throws IOException {
        final Long contentLength = metadata.getContentLength();
        if (contentLength == null || contentLength <= diskCacheSizeThreshold) {
            return null;
        }
        final String fileName = Optional.ofNullable(metadata.getUserMetadata(Metadata.META_FILE_NAME)).orElse("entity.bin");
        final Path tempFile = Files.createTempFile("ootp", toTempFileSuffix(fileName));
        log.debug("Codec 内容较大：{}, 创建文件缓存：{}", contentLength, tempFile);
        return tempFile;
    }

    /**
     * Turns a (possibly externally supplied) file name into a safe temp-file suffix:
     * path separators, reserved and control characters are replaced so that
     * {@link Files#createTempFile} does not reject it, and the length is bounded
     * (keeping the tail so the extension survives).
     */
    private static String toTempFileSuffix(String fileName) {
        String safe = fileName.replaceAll("[\\\\/:*?\"<>|\\p{Cntrl}]", "_");
        int start = safe.length() - MAX_TEMP_FILE_SUFFIX_LENGTH;
        if (start > 0) {
            // Do not cut a surrogate pair in half.
            if (Character.isLowSurrogate(safe.charAt(start))) {
                start++;
            }
            safe = safe.substring(start);
        }
        return safe;
    }

    /** Best-effort removal of a temp file that will not be handed to the caller. */
    private static void deleteQuietly(Path tempFile) {
        if (tempFile == null) {
            return;
        }
        try {
            Files.deleteIfExists(tempFile);
        } catch (IOException e) {
            log.warn("删除临时文件失败：{}", tempFile, e);
        }
    }

    /**
     * Decides whether the content needs encoding/encryption.
     * NONE never encodes; an explicit "false" in {@link ExsCodec#META_CODEC_REQUIRED} never encodes;
     * BY_REQUIRED encodes only when that flag is "true"; ALL encodes everything else.
     *
     * @param metadata request metadata
     * @return true if the content must be processed
     */
    private boolean isRequired(ExsMetadata metadata) {
        if (CodecMode.NONE == mode) {
            log.debug("处理模式为 NONE，不必进行 Codec 处理");
            return false;
        }

        final String required = metadata.getUserMetadata(ExsCodec.META_CODEC_REQUIRED);
        // Explicitly marked as "do not process".
        if (Boolean.FALSE.toString().equalsIgnoreCase(required)) {
            log.debug("当前消息元数据明确指出不必处理，不必进行 Codec 处理");
            return false;
        }

        if (CodecMode.BY_REQUIRED == mode && Boolean.TRUE.toString().equalsIgnoreCase(required)) {
            log.debug("处理模式为 BY_REQUIRED，并且 meta 数据 REQUIRED 为 true，需要进行 Codec 处理");
            return true;
        }

        if (CodecMode.ALL == mode) {
            log.debug("处理模式为 ALL，除了明确不处理的消息外，都要进行 Codec 处理");
            return true;
        }

        return false;
    }

    /**
     * Encodes the entity, caching the result in a temp file (large content) or a byte array, and
     * records the SM3 hash and the encoded length in {@code metadata}.
     */
    @Override
    public HttpEntity encode(HttpEntity entity, ExsMetadata metadata) throws IOException {
        if (!isRequired(metadata)) {
            return entity;
        }

        final MessageDigest sm3 = getSM3Digest();
        final Map<String, Object> context = buildEncodeContext(metadata);
        final Path tempFile = createTempFile(metadata);
        final byte[] bytes;
        try {
            bytes = encodeEntity(entity, metadata, context, sm3, tempFile);
        } catch (IOException | RuntimeException e) {
            // The temp file is never handed out on failure; don't leave it behind.
            deleteQuietly(tempFile);
            throw e;
        }

        metadata.setContentHash(EncryptionUtil.toHexPadded(sm3.digest()));
        if (tempFile != null) {
            final long size = Files.size(tempFile);
            log.debug("编码/加密后使用文件表示：{} - {}", tempFile, size);

            metadata.setContentLength(size);
            final FileCachedInputStream inputStream = new FileCachedInputStream(tempFile);
            return new InputStreamEntity(inputStream, size, ContentType.get(entity));
        }
        log.debug("编码/加密后使用 byte[] 表示：{}", bytes.length);

        metadata.setContentLength((long) bytes.length);
        return new ByteArrayEntity(bytes, ContentType.get(entity));
    }

    /**
     * Encodes an entity, preferring the InputStream pipeline and falling back to the OutputStream
     * pipeline when the entity (or codec) does not support InputStream mode.
     *
     * @return the encoded bytes, or {@code null} when the result was written to {@code tempFile}
     */
    private byte[] encodeEntity(HttpEntity entity, ExsMetadata metadata, Map<String, Object> context,
                                MessageDigest sm3, Path tempFile) throws IOException {
        try {
            return encodeByInputStream(entity.getContent(), metadata, context, sm3, tempFile);
        } catch (UnsupportedOperationException ex) {
            log.debug("ExsBody 不支持 InputStream 模式，转为 OutputStream 模式：{}", metadata);
            // Discard anything the failed InputStream attempt may already have fed into the digest.
            sm3.reset();
            return encodeByOutputStream(entity, metadata, context, sm3, tempFile);
        }
    }

    /**
     * Pulls {@code source} through the encoder and the SM3 digest into {@code tempFile} or memory.
     * {@code source} is closed via the encode stream.
     *
     * @return the encoded bytes, or {@code null} when the result was written to {@code tempFile}
     */
    private byte[] encodeByInputStream(InputStream source, ExsMetadata metadata, Map<String, Object> context,
                                       MessageDigest sm3, Path tempFile) throws IOException {
        try (final InputStream encodeStream = encoder.getEncodeStream(source, metadata.getUserMetadata(), context);
             final InputStream digester = new DigestInputStream(encodeStream, sm3)) {
            if (tempFile != null) {
                Files.copy(digester, tempFile, StandardCopyOption.REPLACE_EXISTING);
                return null;
            }
            return IoUtil.readBytes(digester, false);
        } catch (EncryptionException | com.emc.codec.encryption.EncryptionException e) {
            log.error("编码/加密出错，目标接入方：{}，原因：{}", metadata.getTo(), e.getMessage(), e);
            throw e;
        }
    }

    /**
     * Pushes the entity through the encoder and the SM3 digest into {@code tempFile} or memory.
     *
     * @return the encoded bytes, or {@code null} when the result was written to {@code tempFile}
     */
    private byte[] encodeByOutputStream(HttpEntity entity, ExsMetadata metadata, Map<String, Object> context,
                                        MessageDigest sm3, Path tempFile) throws IOException {
        final ByteArrayOutputStream buffer = (tempFile == null) ? new ByteArrayOutputStream() : null;
        try (final OutputStream sink = (buffer != null) ? buffer : Files.newOutputStream(tempFile);
             final OutputStream digester = new DigestOutputStream(sink, sm3);
             final OutputStream encodeStream = encoder.getEncodeStream(digester, metadata.getUserMetadata(), context)) {
            entity.writeTo(encodeStream);
        } catch (EncryptionException | com.emc.codec.encryption.EncryptionException e) {
            log.error("编码/加密出错，目标接入方：{}，原因：{}", metadata.getTo(), e.getMessage(), e);
            throw e;
        }
        // Read the buffer only after try-with-resources closed the encode stream: encoders that
        // buffer (block ciphers, compression) emit their final block/trailer on close().
        return (buffer != null) ? buffer.toByteArray() : null;
    }

    /**
     * Encodes the stream, caching the result in a temp file (large content) or a byte array, and
     * records the SM3 hash and the encoded length in {@code metadata}. {@code inputStream} is closed.
     */
    @Override
    public InputStream encode(InputStream inputStream, ExsMetadata metadata) throws IOException {
        if (!isRequired(metadata)) {
            return inputStream;
        }

        final MessageDigest sm3 = getSM3Digest();
        final Map<String, Object> context = buildEncodeContext(metadata);
        final Path tempFile = createTempFile(metadata);
        final byte[] bytes;
        try {
            bytes = encodeByInputStream(inputStream, metadata, context, sm3, tempFile);
        } catch (IOException | RuntimeException e) {
            deleteQuietly(tempFile);
            throw e;
        }

        metadata.setContentHash(EncryptionUtil.toHexPadded(sm3.digest()));
        if (tempFile != null) {
            final long size = Files.size(tempFile);
            log.debug("编码/加密后使用文件表示：{} - {}", tempFile, size);
            metadata.setContentLength(size);
            return new FileCachedInputStream(tempFile);
        }
        log.debug("编码/加密后使用 byte[] 表示：{}", bytes.length);
        metadata.setContentLength((long) bytes.length);
        return new ByteArrayInputStream(bytes);
    }

    /** Builds the per-request decoding context (keys of the sender {@code metadata.getFrom()}). */
    protected Map<String,Object> buildDecodeContext(ExsMetadata metadata) {
        final String from = metadata.getFrom();
        final Map<String,Object> context = new HashMap<>();
        Optional.ofNullable(keyProviderManager.getSM4KeyProvider(from)).ifPresent(o -> context.put(SM4EncryptionCodec.PROP_KEY_PROVIDER, o));
        Optional.ofNullable(keyProviderManager.getRSAKeyProvider(from)).ifPresent(o -> context.put(EncryptionCodec.PROP_KEY_PROVIDER, o));
        // The private key is single and global (bound to the decoder), so it is not put in the context.
        return context;
    }

    /** Lets the configured provider (if any) adjust the decoding metadata for the sender. */
    protected void configureDecodingExsMetadata(ExsMetadata metadata) {
        final String from = metadata.getFrom();
        Optional.ofNullable(decodingExsMetadataConfigurerProvider).ifPresent(p -> p.configure(from, metadata));
    }

    /**
     * Returns the original (pre-encoding) size recorded by the first codec of {@code transformMode}
     * (e.g. {@code SMX:SM4/CBC/PKCS7Padding,BIN:base64}), or {@code null} if unknown.
     */
    protected String getUnencodeSize(String transformMode, ExsMetadata metadata) {
        final String encodeType = CodecUtil.getEncodeType(transformMode);
        log.debug("第一个 Codec 的类型：{}", encodeType);
        if (encodeType == null) {
            // switch on a null String would throw NullPointerException
            return null;
        }

        switch (encodeType) {
            case SMXEncryptionConstants.ENCRYPTION_TYPE:
                return metadata.getUserMetadata(SMXEncryptionConstants.META_ENCRYPTION_UNENC_SIZE);
            case SM4EncryptionConstants.ENCRYPTION_TYPE:
                return metadata.getUserMetadata(SM4EncryptionConstants.META_ENCRYPTION_UNENC_SIZE);
            case BinaryEncodeConstants.ENCODE_TYPE:
                return metadata.getUserMetadata(BinaryEncodeConstants.META_BIN_UNENCODED_SIZE);
            case EncryptionConstants.ENCRYPTION_TYPE:
                return metadata.getUserMetadata(EncryptionConstants.META_ENCRYPTION_UNENC_SIZE);
            case CompressionConstants.COMPRESSION_TYPE:
                return metadata.getUserMetadata(CompressionConstants.META_COMPRESSION_UNCOMP_SIZE);
            default:
                return null;
        }
    }

    /** True when the transform mode is absent or whitespace only (nothing to decode). */
    private static boolean isBlank(String transformMode) {
        return transformMode == null || transformMode.trim().isEmpty();
    }

    /**
     * Sets the metadata content length to the recorded pre-encoding size, when present.
     *
     * @throws NumberFormatException if the recorded size is not a number
     */
    private void applyUnencodedSize(String transformMode, ExsMetadata metadata) {
        final String size = getUnencodeSize(transformMode, metadata);
        if (size != null) {
            metadata.setContentLength(Long.parseLong(size.trim()));
        }
    }

    /**
     * Decodes the entity. Streams lazily when possible; otherwise decodes eagerly into a temp file
     * (large content) or a byte array.
     */
    @Override
    public HttpEntity decode(HttpEntity entity, ExsMetadata metadata) throws IOException {
        configureDecodingExsMetadata(metadata);

        final String transformMode = metadata.getUserMetadata(Codec.META_TRANSFORM_MODE);
        if (isBlank(transformMode)) {
            log.debug("{} 为空，不必解密或者解码", Codec.META_TRANSFORM_MODE);
            return entity;
        }

        final Map<String, Object> context = buildDecodeContext(metadata);

        // For messages the content has already been cached locally (byte[] or String) before this call;
        // otherwise the InputStream is only usable until the response (network connection) is closed.
        try {
            applyUnencodedSize(transformMode, metadata);
            final InputStream decodeStream = decoder.getDecodeStream(entity.getContent(), metadata.getUserMetadata(), context);
            return new InputStreamEntity(decodeStream, ContentType.get(entity));
        } catch (EncryptionException | com.emc.codec.encryption.EncryptionException e) {
            log.error("解码/解密出错，目标接入方：{}，原因：{}", metadata.getFrom(), e.getMessage(), e);
            throw e;
        } catch (UnsupportedOperationException ex) {
            log.debug("ExsBody 不支持 InputStream 模式，转为 OutputStream 模式：{}", metadata);
            // Cache the decoded content: on disk when large, otherwise in memory.
            final Path tempFile = createTempFile(metadata);
            final byte[] bytes;
            try {
                bytes = decodeByOutputStream(entity, metadata, context, tempFile);
            } catch (IOException | RuntimeException e) {
                deleteQuietly(tempFile);
                throw e;
            }
            if (tempFile != null) {
                final long size = Files.size(tempFile);
                metadata.setContentLength(size);
                return new InputStreamEntity(new FileCachedInputStream(tempFile), size, ContentType.get(entity));
            }
            metadata.setContentLength((long) bytes.length);
            return new ByteArrayEntity(bytes, ContentType.get(entity));
        }
    }

    /**
     * Pushes the entity through the decoder into {@code tempFile} or memory.
     *
     * @return the decoded bytes, or {@code null} when the result was written to {@code tempFile}
     */
    private byte[] decodeByOutputStream(HttpEntity entity, ExsMetadata metadata, Map<String, Object> context,
                                        Path tempFile) throws IOException {
        final ByteArrayOutputStream buffer = (tempFile == null) ? new ByteArrayOutputStream() : null;
        try (final OutputStream sink = (buffer != null) ? buffer : Files.newOutputStream(tempFile);
             final OutputStream decodeStream = decoder.getDecodeStream(sink, metadata.getUserMetadata(), context)) {
            entity.writeTo(decodeStream);
        } catch (EncryptionException | com.emc.codec.encryption.EncryptionException e) {
            log.error("解码/解密出错，目标接入方：{}，原因：{}", metadata.getFrom(), e.getMessage(), e);
            throw e;
        }
        // Read only after the decode stream is closed so its final flushed block is included.
        return (buffer != null) ? buffer.toByteArray() : null;
    }

    /** Returns a lazily decoding view of {@code inputStream} (or the stream itself if nothing to decode). */
    @Override
    public InputStream decode(InputStream inputStream, ExsMetadata metadata) throws IOException {
        configureDecodingExsMetadata(metadata);

        final String transformMode = metadata.getUserMetadata(Codec.META_TRANSFORM_MODE);
        if (isBlank(transformMode)) {
            log.debug("{} 为空，不必解密或者解码", Codec.META_TRANSFORM_MODE);
            return inputStream;
        }

        applyUnencodedSize(transformMode, metadata);

        final Map<String, Object> context = buildDecodeContext(metadata);
        return decoder.getDecodeStream(inputStream, metadata.getUserMetadata(), context);
    }

    /** Returns an OutputStream that decodes into {@code outputStream} (or the stream itself if nothing to decode). */
    @Override
    public OutputStream decode(OutputStream outputStream, ExsMetadata metadata) throws IOException {
        configureDecodingExsMetadata(metadata);

        final String transformMode = metadata.getUserMetadata(Codec.META_TRANSFORM_MODE);
        if (isBlank(transformMode)) {
            log.debug("{} 为空，不必解密或者解码", Codec.META_TRANSFORM_MODE);
            return outputStream;
        }

        applyUnencodedSize(transformMode, metadata);

        final Map<String, Object> context = buildDecodeContext(metadata);
        return decoder.getDecodeStream(outputStream, metadata.getUserMetadata(), context);
    }

}
