/**
 * Copyright (c) 2016-2021, Bosco.Liao (bosco_liao@126.com).
 * <p>
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not
 * use this file except in compliance with the License. You may obtain a copy of
 * the License at
 * <p>
 * http://www.apache.org/licenses/LICENSE-2.0
 * <p>
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */
package org.iherus.shiro.cache.redis;

import org.apache.shiro.cache.Cache;
import org.apache.shiro.cache.CacheException;
import org.iherus.shiro.cache.redis.ExpiredCache.Named;
import org.iherus.shiro.cache.redis.connection.MutableDatabase;
import org.iherus.shiro.cache.redis.connection.RedisConnection;
import org.iherus.shiro.util.Md5Utils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.*;
import java.util.stream.Collectors;

import static org.iherus.shiro.cache.redis.Constant.DEFAULT_CACHE_EXPIRATION;
import static org.iherus.shiro.cache.redis.serializer.StringSerializer.UTF_8;
import static org.iherus.shiro.util.Utils.*;

/**
 * <p>Cache interface implementation.</p>
 * <p>Description: implements Shiro's {@link Cache} interface on top of Redis, providing
 * create / read / update / delete operations on centrally cached objects.</p>
 * <p>Every entry is stored under the Redis key {@code keyPrefix + key}. The bulk operations
 * ({@link #clear()}, {@link #size()}, {@link #keys()}, {@link #values()}) address all Redis keys
 * matching the glob pattern {@code keyPrefix + "*"}.</p>
 *
 * @author Bosco.Liao
 * @since 1.0.0
 */
public class RedisCache<K, V> implements Cache<K, V>, ExpiredCache<K, V>, Named {

    private static final Logger logger = LoggerFactory.getLogger(RedisCache.class);

    /** Logical cache name, used in log output. */
    private final String name;
    /** Executes callbacks against a Redis connection and supplies the key/value serializers. */
    private final RedisOperations operations;
    /** Prefix prepended to every Redis key of this cache; also scopes the bulk operations. */
    private final String keyPrefix;
    /** Default TTL for {@link #put(Object, Object)}; {@code null} falls back to DEFAULT_CACHE_EXPIRATION. */
    private final Duration expiration;
    /** Redis database to switch to before each operation; absent or negative means "don't switch". */
    private final Optional<Integer> database;

    /**
     * Creates a cache view over Redis.
     *
     * @param name       cache name; must not be blank
     * @param operations Redis access helper; must not be null
     * @param keyPrefix  prefix for all keys of this cache (not validated here)
     * @param expiration default entry TTL, or {@code null} to use DEFAULT_CACHE_EXPIRATION
     * @param database   database index to select on mutable connections, or {@code null}
     */
    public RedisCache(String name, RedisOperations operations, String keyPrefix, Duration expiration,
                      Integer database) {
        assertNotBlank(name, "Name must not be blank.");
        assertNotNull(operations, "RedisOperations must not be null.");

        this.name = name;
        this.operations = operations;
        this.keyPrefix = keyPrefix;
        this.expiration = expiration;
        this.database = Optional.ofNullable(database);
    }

    @Override
    public String getName() {
        return this.name;
    }

    public RedisOperations getOperations() {
        return operations;
    }

    public String getKeyPrefix() {
        return keyPrefix;
    }

    public Duration getExpiration() {
        return expiration;
    }

    public Optional<Integer> getDatabase() {
        return database;
    }

    /**
     * Switches the connection to the configured database. The switch only happens when a
     * non-negative database is configured and the connection implements {@link MutableDatabase}.
     * <p>
     * This runs before every cache operation, so it logs at DEBUG rather than INFO to avoid
     * flooding the logs.
     *
     * @param connection the connection about to be used; may be {@code null} (the switch is then skipped)
     */
    protected void selectDatabase(RedisConnection connection) {
        int db = this.database.orElse(-1);
        if (db >= 0 && connection instanceof MutableDatabase) {
            ((MutableDatabase) connection).setDatabase(db);

            if (logger.isDebugEnabled()) {
                logger.debug("The database has been dynamically set to {} for {} instance.", db, connection.getClass().getCanonicalName());
            }
        } else if (logger.isDebugEnabled()) {
            logger.debug("Skip dynamically setting the database, current connection type is: {}, database is: {}.",
                    (connection == null ? "Null" : connection.getClass().getSimpleName()), this.database.orElse(null));
        }
    }

    /**
     * Reads the cached value for {@code key}.
     *
     * @return the deserialized value, or {@code null} if {@code key} is null or nothing is cached
     * @throws CacheException wrapping any failure raised while talking to Redis or deserializing
     */
    @SuppressWarnings("unchecked")
    @Override
    public V get(K key) throws CacheException {
        if (logger.isDebugEnabled()) {
            logger.debug("Getting object from cache [{}] for key [{}]", getName(), key);
        }

        if (key == null) {
            return null;
        }

        try {
            return (V) operations.execute((connection) -> {
                selectDatabase(connection);
                byte[] value = connection.get(getKeyToBytes(key));
                if (logger.isDebugEnabled() && (!isEmpty(value))) {
                    logger.debug("Cache for [{}] is exist, ready to deserialize it.", key);
                }
                return deserializeValue(value);
            });
        } catch (Exception e) {
            throw new CacheException(e);
        }
    }

    /**
     * Stores {@code value} using the cache's default expiration.
     *
     * @see #put(Object, Object, Duration)
     */
    @Override
    public V put(K key, V value) throws CacheException {
        return put(key, value, null);
    }

    /**
     * Stores {@code value} under {@code key} with the given TTL.
     *
     * @param expired TTL for this entry; {@code null} uses {@link #getExpiration()}, and if that is
     *                also {@code null}, DEFAULT_CACHE_EXPIRATION
     * @return the deserialized previous value as returned by the connection's {@code set}, or
     *         {@code null}; when {@code key} is null nothing is stored and {@code value} is returned
     * @throws CacheException wrapping any failure raised while talking to Redis or (de)serializing
     */
    @SuppressWarnings("unchecked")
    @Override
    public V put(K key, V value, Duration expired) throws CacheException {
        if (logger.isDebugEnabled()) {
            logger.debug("Putting object in cache [{}] for key [{}]", getName(), key);
        }

        if (key == null) {
            return value;
        }

        final Duration expirationToUse = expired != null
                ? expired
                : (getExpiration() == null ? DEFAULT_CACHE_EXPIRATION : getExpiration());

        try {
            return (V) operations.execute((connection) -> {
                selectDatabase(connection);
                byte[] previous = connection.set(getKeyToBytes(key), serializeValue(value), expirationToUse);
                return deserializeValue(previous);
            });
        } catch (Exception e) {
            throw new CacheException(e);
        }
    }

    /**
     * Deletes the entry for {@code key}.
     *
     * @return the deserialized removed value as returned by the connection's {@code del}, or
     *         {@code null} (also when {@code key} is null)
     * @throws CacheException wrapping any failure raised while talking to Redis or deserializing
     */
    @SuppressWarnings("unchecked")
    @Override
    public V remove(K key) throws CacheException {
        if (logger.isDebugEnabled()) {
            logger.debug("Removing object from cache [{}] for key [{}]", getName(), key);
        }
        if (key == null) {
            return null;
        }

        try {
            return (V) operations.execute((connection) -> {
                selectDatabase(connection);

                byte[] previous = connection.del(getKeyToBytes(key));

                if (logger.isDebugEnabled() && (!isEmpty(previous))) {
                    logger.debug("Remove key [{}] successfully.", key);
                }
                return deserializeValue(previous);
            });
        } catch (Exception e) {
            throw new CacheException(e);
        }
    }

    /**
     * Deletes every Redis key matching {@code keyPrefix + "*"}.
     *
     * @throws CacheException wrapping any failure raised while talking to Redis
     */
    @Override
    public void clear() throws CacheException {
        if (logger.isDebugEnabled()) {
            logger.debug("Clear all cached objects.");
        }

        try {
            operations.execute((connection) -> {
                selectDatabase(connection);

                Set<byte[]> allKeys = connection.keys(getKeyPattern());

                // Nothing to delete. Skip the bulk delete: it presumably maps to DEL, which
                // Redis rejects when called without any key.
                if (allKeys == null || allKeys.isEmpty()) {
                    return 0L;
                }

                if (logger.isDebugEnabled()) {
                    logger.debug("Currently scanning to {} keys, ready to clear.", allKeys.size());
                }

                Long c = connection.mdel(allKeys.toArray(new byte[allKeys.size()][]));

                if (logger.isDebugEnabled() && c != null) {
                    logger.debug("After the cleanup is completed, this task clears {} cache objects.", c);
                }

                return c;
            });

        } catch (Exception e) {
            throw new CacheException(e);
        }
    }

    /**
     * Counts the Redis keys matching {@code keyPrefix + "*"}.
     *
     * @return the count, 0 if the connection reports nothing, clamped to {@link Integer#MAX_VALUE}
     * @throws CacheException wrapping any failure raised while talking to Redis
     */
    @Override
    public int size() {
        try {
            Number count = operations.execute((connection) -> {
                selectDatabase(connection);
                return connection.size(getKeyPattern());
            });
            if (count == null) {
                return 0;
            }
            // Shiro's contract is an int; clamp instead of silently wrapping around.
            return (int) Math.min(count.longValue(), Integer.MAX_VALUE);
        } catch (Exception e) {
            throw new CacheException(e);
        }
    }

    /**
     * Lists all keys of this cache.
     * <p>
     * Because the final type of the cache instance is {@literal Cache<Object, Object>}, each
     * returned key is the deserialized full Redis key (a String, prefix included). Passing such a
     * String back to {@link #get(Object)} works, because the leading prefix is stripped before
     * re-prefixing. For non-String original keys, the returned String is the hashed form, not the
     * original object.
     *
     * @return an unmodifiable set of keys (possibly empty)
     * @throws CacheException wrapping any failure raised while talking to Redis
     */
    @SuppressWarnings("unchecked")
    @Override
    public Set<K> keys() {
        try {
            return operations.execute((connection) -> {
                selectDatabase(connection);

                Set<byte[]> allKeys = connection.keys(getKeyPattern());

                if (allKeys == null || allKeys.isEmpty()) {
                    return Collections.<K>emptySet();
                }

                if (logger.isDebugEnabled()) {
                    logger.debug("Currently scanning to {} keys.", allKeys.size());
                }

                Set<K> result = allKeys.stream()
                        .map(rawKey -> (K) deserializeKey(rawKey))
                        .filter(Objects::nonNull)
                        .collect(Collectors.toSet());
                return Collections.unmodifiableSet(result);
            });
        } catch (Exception e) {
            throw new CacheException(e);
        }
    }

    /**
     * Fetches all values of this cache. Entries that vanish between the key scan and the bulk
     * read (e.g. expired) are skipped rather than returned as {@code null}.
     *
     * @return an unmodifiable list of values (possibly empty)
     * @throws CacheException wrapping any failure raised while talking to Redis or deserializing
     */
    @Override
    public Collection<V> values() {
        try {
            return operations.execute((connection) -> {
                selectDatabase(connection);

                Set<byte[]> allKeys = connection.keys(getKeyPattern());

                // Nothing cached. Skip the bulk read: it presumably maps to MGET, which Redis
                // rejects when called without any key.
                if (allKeys == null || allKeys.isEmpty()) {
                    return Collections.<V>emptyList();
                }

                List<byte[]> allValues = connection.mget(allKeys.toArray(new byte[allKeys.size()][]));

                if (allValues == null || allValues.isEmpty()) {
                    return Collections.<V>emptyList();
                }

                if (logger.isDebugEnabled()) {
                    logger.debug("Currently scanning to {} key-values.", allValues.size());
                }

                @SuppressWarnings("unchecked")
                List<V> result = allValues.stream()
                        .map(it -> (V) deserializeValue(it))
                        .filter(Objects::nonNull)
                        .collect(Collectors.toList());

                return Collections.unmodifiableList(result);
            });
        } catch (Exception e) {
            throw new CacheException(e);
        }
    }

    /**
     * Gets bytes key with key prefix.
     */
    protected byte[] getKeyToBytes(K key) {
        return getKeyToBytes(key, getKeyPrefix());
    }

    /**
     * Builds the full Redis key ({@code keyPrefix + key}) as bytes.
     * <ul>
     * <li>{@code byte[]} keys are used verbatim after the prefix.</li>
     * <li>String keys have a single <em>leading</em> {@code keyPrefix} stripped (so keys returned by
     * {@link #keys()} round-trip), then are serialized with the key serializer.</li>
     * <li>Any other key is serialized with the value serializer and replaced by a hash-based String.</li>
     * </ul>
     */
    private byte[] getKeyToBytes(K key, String keyPrefix) {
        byte[] keyBytes;
        if (key instanceof byte[]) {
            keyBytes = (byte[]) key;
        } else if (key instanceof String) {
            keyBytes = serializeKey(stripLeadingPrefix((String) key, keyPrefix));
        } else {
            /*
             * This scheme only works when, during authentication
             * [org.apache.shiro.realm.AuthorizingRealm.doGetAuthenticationInfo(token)],
             * the [getPrincipals] method of the returned [org.apache.shiro.authc.AuthenticationInfo]
             * object always returns an unchanging result.
             *
             * Otherwise, when writing an AuthorizingRealm subclass, override the following two
             * methods to get the expected behavior:
             * 1) AuthorizingRealm.getAuthorizationCacheKey(PrincipalCollection principals);
             * 2) AuthenticatingRealm.getAuthenticationCacheKey(PrincipalCollection principals).
             */
            if (logger.isWarnEnabled()) {
                logger.warn("The current cache key is not byte[] or String type, and the key should be kept unchanged in subsequent operations.");
            }
            byte[] obytes = serializeValue(key);
            // Decode with a fixed charset: the platform default would make the hash (and thus
            // the Redis key) differ between JVMs configured with different default charsets.
            String keyToUse = Md5Utils.getMd5(obytes, key.getClass().getSimpleName())
                    + new String(obytes, StandardCharsets.UTF_8).hashCode();
            keyBytes = serializeKey(keyToUse);
        }
        return mergeAll(serializeKey(keyPrefix), keyBytes);
    }

    /**
     * Removes {@code keyPrefix} from the start of {@code key}, if present. Only a leading
     * occurrence is removed: removing every occurrence would make distinct keys such as
     * {@code "a" + prefix + "b"} and {@code "ab"} collide.
     */
    private static String stripLeadingPrefix(String key, String keyPrefix) {
        if (keyPrefix == null || keyPrefix.isEmpty() || !key.startsWith(keyPrefix)) {
            return key;
        }
        return key.substring(keyPrefix.length());
    }

    /**
     * Returns the serialized glob pattern ({@code keyPrefix + "*"}) matching every key of this cache.
     */
    private byte[] getKeyPattern() {
        return UTF_8.serialize(getKeyPrefix() + "*");
    }

    private byte[] serializeKey(String key) {
        return this.operations.getKeySerializer().serialize(key);
    }

    /** Deserializes a raw Redis key; empty input yields {@code null}. */
    private String deserializeKey(byte[] bytes) {
        if (isEmpty(bytes)) return null;
        return this.operations.getKeySerializer().deserialize(bytes);
    }

    private byte[] serializeValue(Object value) {
        return this.operations.getValueSerializer().serialize(value);
    }

    /** Deserializes a raw Redis value; empty input (including "no entry") yields {@code null}. */
    private Object deserializeValue(byte[] bytes) {
        if (isEmpty(bytes)) return null;
        return this.operations.getValueSerializer().deserialize(bytes);
    }


}