/**
 * Copyright (c) 2016-2021, Bosco.Liao (bosco_liao@126.com).
 * <p>
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not
 * use this file except in compliance with the License. You may obtain a copy of
 * the License at
 * <p>
 * http://www.apache.org/licenses/LICENSE-2.0
 * <p>
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */
package org.iherus.shiro.cache.redis.connection.spring;

import org.iherus.shiro.cache.redis.connection.AbstractRedisConnection;
import org.iherus.shiro.cache.redis.connection.BatchOptions;
import org.iherus.shiro.cache.redis.connection.RedisConnection;
import org.iherus.shiro.cache.redis.serializer.StringSerializer;
import org.iherus.shiro.util.RedisVerUtils;
import org.springframework.data.redis.connection.ClusterSlotHashUtil;
import org.springframework.data.redis.connection.RedisClusterConnection;
import org.springframework.data.redis.connection.RedisClusterNode;
import org.springframework.data.redis.connection.ReturnType;
import org.springframework.data.redis.connection.jedis.JedisClusterConnection;
import org.springframework.data.redis.core.Cursor;
import org.springframework.data.redis.core.ScanOptions;

import java.nio.ByteBuffer;
import java.time.Duration;
import java.util.*;
import java.util.Map.Entry;
import java.util.function.Function;

import static org.iherus.shiro.cache.redis.Constant.GETDEL;
import static org.iherus.shiro.cache.redis.Constant.GETSET;
import static org.iherus.shiro.util.Utils.*;

/**
 * Implementation of {@link RedisConnection} for compatibility with
 * {@link org.springframework.data.redis.connection.RedisConnection}.
 * <p>
 * Commands are delegated to the wrapped Spring connection, which may be a standalone or a
 * cluster connection. Multi-key operations are executed in batches sized by the supplied
 * {@link BatchOptions}. For cluster connections, a key-to-slot calculator based on
 * {@link ClusterSlotHashUtil} is handed to the batching helpers.
 *
 * @author Bosco.Liao
 * @since 2.0.0
 */
public class CompatibleRedisConnection extends AbstractRedisConnection implements RedisConnection {

    /** Computes the Redis Cluster hash slot of a key; passed to the cluster batching helpers. */
    private static final Function<byte[], Integer> SLOT_CALCULATOR = ClusterSlotHashUtil::calculateSlot;

    private final org.springframework.data.redis.connection.RedisConnection nativeRedisConnection;
    private final BatchOptions options;

    /**
     * Creates a connection wrapper that uses {@link BatchOptions#defaulted}.
     *
     * @param redisConnection the Spring Data Redis connection to delegate to; must not be null
     */
    public CompatibleRedisConnection(org.springframework.data.redis.connection.RedisConnection redisConnection) {
        this(redisConnection, BatchOptions.defaulted);
    }

    /**
     * Creates a connection wrapper with explicit batch options.
     *
     * @param redisConnection the Spring Data Redis connection to delegate to; must not be null
     * @param options         batch sizes for delete / fetch / scan operations; must not be null
     * @throws NullPointerException if either argument is null
     */
    public CompatibleRedisConnection(org.springframework.data.redis.connection.RedisConnection redisConnection,
                                     BatchOptions options) {
        // Fail fast instead of surfacing an NPE on the first command.
        this.nativeRedisConnection = Objects.requireNonNull(redisConnection, "redisConnection must not be null");
        this.options = Objects.requireNonNull(options, "options must not be null");
    }

    /**
     * @return the batch options used by this connection
     */
    public BatchOptions getOptions() {
        return options;
    }

    /**
     * Returns the value stored at {@code key} via the native {@code get}.
     */
    @Override
    public byte[] get(byte[] key) {
        return nativeRedisConnection.get(key);
    }

    /**
     * Stores {@code value} at {@code key} by evaluating the {@code GETSET} Lua script, and returns
     * the script's result.
     *
     * @param key     the key
     * @param value   the new value
     * @param expired time to live; {@link Duration#ZERO} or {@code null} are sent as {@code -1}
     * @return whatever the script returns (evaluated with {@link ReturnType#VALUE})
     */
    @Override
    public byte[] set(byte[] key, byte[] value, Duration expired) {
        byte[] command = StringSerializer.UTF_8.serialize(GETSET);
        // ZERO (and, defensively, null) are encoded as -1; presumably the script's "no expiry"
        // marker -- NOTE(review): confirm against the GETSET script.
        long millis = (expired == null || Duration.ZERO.equals(expired)) ? -1L : expired.toMillis();
        return nativeRedisConnection.eval(command, ReturnType.VALUE, 1, key, value, longToBytes(millis));
    }

    /**
     * Deletes the given keys in batches of {@link BatchOptions#getDeleteBatchSize()}.
     * <p>
     * The command is UNLINK when the server version reports support for it, otherwise DEL.
     *
     * @param keys the keys to delete
     * @return the aggregated result of the batch deletes, or {@code 0} when {@code keys} is empty
     */
    @Override
    public Long mdel(byte[]... keys) {
        if (isEmpty(keys)) return 0L;

        final boolean unlink = RedisVerUtils.getServerVersion(
                () -> getServerVersion(isClusterConnection(), getNativeConnection().info("Server")))
                .isSupportUnlink();
        // Connection type does not change between batches, so resolve it once.
        final boolean jedisCluster = isJedisClusterConnection();

        final Function<byte[][], Long> executor = batchKeys -> {
            if (jedisCluster) {
                /*
                 * Spring encapsulated methods do not perform as well as JedisCluster's native
                 * methods, so the native methods are invoked reflectively instead.
                 */
                Object primitive = getNativeConnection().getNativeConnection();
                Object deleted = invokeMethod(primitive, unlink ? "unlink" : "del",
                        new Class[]{byte[][].class}, (Object) batchKeys);
                return deleted == null ? 0L : (Long) deleted;
            }
            return unlink ? getNativeConnection().unlink(batchKeys) : getNativeConnection().del(batchKeys);
        };

        if (isClusterConnection()) {
            return batchDeleteOnCluster(this.options.getDeleteBatchSize(), keys, executor, SLOT_CALCULATOR);
        }
        return batchDeleteOnStandalone(this.options.getDeleteBatchSize(), keys, executor);
    }

    /**
     * Fetches the values of the given keys in batches of {@link BatchOptions#getFetchBatchSize()}.
     *
     * @param keys the keys to read
     * @return the fetched values; an empty list when {@code keys} is empty
     */
    @Override
    public List<byte[]> mget(byte[]... keys) {
        // Mirrors the guard in mdel: never issue a key-less MGET.
        if (isEmpty(keys)) return new ArrayList<>();

        if (isClusterConnection()) {
            return mgetOnCluster(keys);
        }
        return batchGetOnStandalone(options.getFetchBatchSize(), keys,
                batchKeys -> getNativeConnection().mGet(batchKeys));
    }

    /**
     * Cluster variant of {@link #mget(byte[]...)}. On a Jedis cluster connection the native
     * {@code mget} is invoked reflectively; otherwise Spring's {@code mGet} is used.
     *
     * @param keys the keys to read
     * @return the fetched values
     */
    @SuppressWarnings("unchecked") // the reflective JedisCluster#mget call returns List<byte[]>
    protected List<byte[]> mgetOnCluster(byte[]... keys) {
        final boolean jedisCluster = isJedisClusterConnection();
        return batchGetOnCluster(this.options.getFetchBatchSize(), keys, batchKeys -> {
            if (jedisCluster) {
                Object primitive = getNativeConnection().getNativeConnection();
                return (List<byte[]>) invokeMethod(primitive, "mget", new Class[]{byte[][].class},
                        (Object) batchKeys);
            }
            return getNativeConnection().mGet(batchKeys);
        }, SLOT_CALCULATOR);
    }

    /**
     * Removes {@code key} by evaluating the {@code GETDEL} Lua script.
     *
     * @param key the key to remove
     * @return whatever the script returns (evaluated with {@link ReturnType#VALUE})
     */
    @Override
    public byte[] del(byte[] key) {
        byte[] command = StringSerializer.UTF_8.serialize(GETDEL);
        return nativeRedisConnection.eval(command, ReturnType.VALUE, 1, key);
    }

    /**
     * Collects the keys matching {@code pattern} using SCAN, on every master when connected to a
     * cluster.
     *
     * @param pattern the UTF-8 encoded match pattern
     * @return the matching keys, deduplicated by content per scanned node
     */
    @Override
    public Set<byte[]> keys(byte[] pattern) {
        if (isClusterConnection()) {
            return getClusterKeys(pattern);
        }
        return getKeys(pattern);
    }

    /**
     * @return {@code true} if the wrapped connection is a {@link RedisClusterConnection}
     */
    @Override
    public boolean isClusterConnection() {
        return getNativeConnection() instanceof RedisClusterConnection;
    }

    /**
     * Closes the wrapped Spring connection.
     */
    @Override
    public void close() {
        this.getNativeConnection().close();
    }

    /**
     * @return {@code true} if the wrapped connection is a Jedis-backed cluster connection
     */
    protected boolean isJedisClusterConnection() {
        return getNativeConnection() instanceof JedisClusterConnection;
    }

    /**
     * @return the wrapped Spring Data Redis connection
     */
    public org.springframework.data.redis.connection.RedisConnection getNativeConnection() {
        return nativeRedisConnection;
    }

    /**
     * Scans a standalone server for keys matching {@code pattern}.
     *
     * @param pattern the UTF-8 encoded match pattern
     * @return the matching keys
     */
    protected Set<byte[]> getKeys(byte[] pattern) {
        ScanOptions scanOptions = buildScanOptions(StringSerializer.UTF_8.deserialize(pattern));
        return drainCursor(getNativeConnection().scan(scanOptions));
    }

    /**
     * Scans every master node of the cluster for keys matching {@code pattern}. One task per master
     * is submitted to the completion service supplied by {@code distributionScanKeys}.
     *
     * @param pattern the UTF-8 encoded match pattern
     * @return the keys gathered from all masters
     */
    protected Set<byte[]> getClusterKeys(byte[] pattern) {
        final String patternAsText = StringSerializer.UTF_8.deserialize(pattern);
        return distributionScanKeys(completion -> {
            Set<RedisClusterNode> masters = ((RedisClusterConnection) getNativeConnection())
                    .clusterGetMasterSlaveMap().keySet();
            masters.forEach(node -> completion.submit(() -> getKeysFromNode(node, patternAsText)));
            // Tells the caller how many results to wait for.
            return masters.size();
        });
    }

    /** Scans a single cluster node for keys matching {@code pattern}. */
    private Set<byte[]> getKeysFromNode(RedisClusterNode node, String pattern) {
        return drainCursor(((RedisClusterConnection) getNativeConnection()).scan(node, buildScanOptions(pattern)));
    }

    /** Builds SCAN options with the given match pattern and the configured scan batch size. */
    private ScanOptions buildScanOptions(String pattern) {
        return ScanOptions.scanOptions().match(pattern).count(this.options.getScanBatchSize()).build();
    }

    /**
     * Reads all remaining elements from {@code cursor} and always closes it.
     * <p>
     * SCAN may return the same key more than once. Arrays use identity equality, so a plain
     * {@code HashSet<byte[]>} would not drop the duplicates; they are filtered by content instead.
     */
    private static Set<byte[]> drainCursor(Cursor<byte[]> cursor) {
        Set<byte[]> keys = new HashSet<>();
        Set<ByteBuffer> seen = new HashSet<>();
        try (Cursor<byte[]> scanning = cursor) {
            while (scanning.hasNext()) {
                byte[] key = scanning.next();
                if (seen.add(ByteBuffer.wrap(key))) {
                    keys.add(key);
                }
            }
        } catch (RuntimeException e) {
            throw e;
        } catch (Exception e) {
            // Depending on the Spring Data Redis version, Cursor#close() may declare a checked exception.
            throw new IllegalStateException("Failed to close the SCAN cursor.", e);
        }
        return keys;
    }

    /**
     * Extracts the {@code redis_version} value from an INFO "Server" result.
     * <p>
     * For cluster connections, the first property whose key contains {@code redis_version} is used.
     * The keys there presumably carry a per-node prefix -- NOTE(review): confirm.
     *
     * @param cluster    whether the properties came from a cluster connection
     * @param properties the INFO result; may be null
     * @return the version string, or {@code EMPTY_STRING} when it cannot be determined
     */
    private static String getServerVersion(boolean cluster, final Properties properties) {
        if (properties == null) {
            return EMPTY_STRING;
        }
        if (!cluster) {
            Object version = properties.get("redis_version");
            return version == null ? EMPTY_STRING : version.toString();
        }
        for (Entry<Object, Object> entry : properties.entrySet()) {
            if (entry.getKey().toString().contains("redis_version")) {
                return entry.getValue().toString();
            }
        }
        return EMPTY_STRING;
    }

}