/**
 * Copyright (c) 2016-2019, Bosco.Liao (bosco_liao@126.com).
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not
 * use this file except in compliance with the License. You may obtain a copy of
 * the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */
package org.iherus.shiro.cache.redis.connection.spring;

import static org.iherus.shiro.cache.redis.Constant.GETDEL;
import static org.iherus.shiro.cache.redis.Constant.GETSET;
import static org.iherus.shiro.util.Utils.invokeMethod;
import static org.iherus.shiro.util.Utils.isEmpty;
import static org.iherus.shiro.util.Utils.longToBytes;

import java.time.Duration;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.Properties;
import java.util.Set;
import java.util.function.Function;

import org.iherus.shiro.cache.redis.connection.AbstractRedisConnection;
import org.iherus.shiro.cache.redis.connection.BatchOptions;
import org.iherus.shiro.cache.redis.connection.RedisConnection;
import org.iherus.shiro.cache.redis.connection.redisson.AbstractRedissonConnection;
import org.iherus.shiro.cache.redis.connection.redisson.RedissonClusterConnection;
import org.iherus.shiro.cache.redis.connection.redisson.RedissonConnection;
import org.iherus.shiro.cache.redis.serializer.StringSerializer;
import org.iherus.shiro.util.RedisVerUtils;
import org.redisson.Redisson;
import org.springframework.data.redis.connection.ClusterSlotHashUtil;
import org.springframework.data.redis.connection.RedisClusterConnection;
import org.springframework.data.redis.connection.RedisClusterNode;
import org.springframework.data.redis.connection.ReturnType;
import org.springframework.data.redis.connection.jedis.JedisClusterConnection;
import org.springframework.data.redis.core.Cursor;
import org.springframework.data.redis.core.ScanOptions;
import org.springframework.util.ClassUtils;

/**
 * Implementation of {@link RedisConnection} for compatibility with
 * {@link org.springframework.data.redis.connection.RedisConnection}.
 * <p>
 * If the wrapped Spring connection is an instance of Redisson's
 * {@code org.redisson.spring.data.connection.RedissonConnection} (looked up
 * reflectively, so Redisson stays optional), {@code set}, {@code mdel},
 * {@code mget}, {@code del} and {@code keys} are delegated to a Redisson-based
 * connection. Otherwise they are executed through the Spring connection API,
 * with Jedis-cluster native calls used for batched delete/get.
 * 
 * @author Bosco.Liao
 * @since 2.0.0
 */
public class CompatibleRedisConnection extends AbstractRedisConnection implements RedisConnection {

	/**
	 * Redisson's Spring Data connection class, or {@code null} when it could not
	 * be loaded (since v2.1.0).
	 */
	private static final Class<?> _redissonConnectionClass = loadRedissonConnectionClass();

	/** Computes the cluster hash slot of a key; used to group keys into per-slot batches. */
	private static final Function<byte[], Integer> calculator = ClusterSlotHashUtil::calculateSlot;

	/** Redisson-based delegate, or {@code null} when the wrapped connection is not Redisson's. */
	private final AbstractRedissonConnection redissonDelegate;

	/** The wrapped Spring Data Redis connection. */
	private final org.springframework.data.redis.connection.RedisConnection nativeConnection;

	/** Batch sizes used for delete, fetch and scan operations. */
	private final BatchOptions options;

	/**
	 * Tries to load Redisson's Spring Data connection class via the thread
	 * context class loader.
	 *
	 * @return the loaded class, or {@code null} if it cannot be loaded
	 */
	private static Class<?> loadRedissonConnectionClass() {
		try {
			return ClassUtils.forName("org.redisson.spring.data.connection.RedissonConnection",
					Thread.currentThread().getContextClassLoader());
		} catch (Throwable ignored) {
			// Redisson is optional: any failure to load it simply disables delegation.
			return null;
		}
	}

	/**
	 * Wraps the given Spring connection using {@link BatchOptions#defaulted}.
	 *
	 * @param connection the Spring Data Redis connection to adapt
	 */
	public CompatibleRedisConnection(org.springframework.data.redis.connection.RedisConnection connection) {
		this(connection, BatchOptions.defaulted);
	}

	/**
	 * Wraps the given Spring connection with explicit batch options.
	 *
	 * @param connection the Spring Data Redis connection to adapt
	 * @param options    batch sizes for delete, fetch and scan operations
	 */
	public CompatibleRedisConnection(org.springframework.data.redis.connection.RedisConnection connection,
			BatchOptions options) {
		this.nativeConnection = connection;
		this.options = options;

		if (!isRedissonConnection()) {
			this.redissonDelegate = null;
		} else {
			// The Redisson Spring connection exposes the Redisson client as its native connection.
			Redisson redisson = (Redisson) this.nativeConnection.getNativeConnection();
			this.redissonDelegate = isClusterConnection() ? new RedissonClusterConnection(redisson, options)
					: new RedissonConnection(redisson, options);
		}
	}

	/**
	 * @return the batch options this connection was created with
	 */
	public BatchOptions getOptions() {
		return options;
	}

	/**
	 * Reads the value stored at {@code key} through the Spring connection.
	 */
	@Override
	public byte[] get(byte[] key) {
		return this.nativeConnection.get(key);
	}

	/**
	 * Stores {@code value} at {@code key} with the given expiration.
	 * <p>
	 * Without a Redisson delegate, this evaluates the {@code GETSET} script
	 * constant with the key, the value and the expiration in milliseconds.
	 * {@code expired} must not be {@code null} on that path.
	 *
	 * @return the script (or delegate) result
	 */
	@Override
	public byte[] set(byte[] key, byte[] value, Duration expired) {
		if (redissonDelegate != null) {
			return (byte[]) invokeMethod(redissonDelegate, "set",
					new Class[] { byte[].class, byte[].class, Duration.class }, new Object[] { key, value, expired });
		}

		byte[] command = StringSerializer.UTF_8.serialize(GETSET);
		return this.nativeConnection.eval(command, ReturnType.VALUE, 1, key, value, longToBytes(expired.toMillis()));
	}

	/**
	 * Deletes the given keys in batches.
	 * <p>
	 * Uses {@code UNLINK} when the server version reports support for it, and
	 * {@code DEL} otherwise. On a cluster, keys are grouped by hash slot. On a
	 * Jedis cluster, the native {@code JedisCluster} methods are invoked
	 * reflectively.
	 *
	 * @return the number of deleted keys; {@code 0} for empty input
	 */
	@Override
	public Long mdel(byte[]... keys) {
		if (redissonDelegate != null) {
			return (Long) invokeMethod(redissonDelegate, "mdel", byte[][].class, keys);
		}

		if (isEmpty(keys)) return 0L;

		boolean unlink = RedisVerUtils.getServerVersion(() -> {
			return getServerVersion(isClusterConnection(), getNativeConnection().info("Server"));
		}).isSupportUnlink();

		final Function<byte[][], Long> executor = ((batchKeys) -> {

			if (isJedisClusterConnection()) {
				/**
				 * Spring encapsulated methods do not perform as well as JedisCluster's native
				 * methods, so use native methods to rewrite.
				 */
				Object primitive = getNativeConnection().getNativeConnection();

				return (Long) Optional.ofNullable(invokeMethod(primitive, (unlink ? "unlink" : "del"),
						new Class[] { byte[][].class }, ((Object) batchKeys))).orElse(0L);
			}

			return unlink ? getNativeConnection().unlink(batchKeys) : getNativeConnection().del(batchKeys);
		});

		if (isClusterConnection()) {
			return batchDeleteOnCluster(this.options.getDeleteBatchSize(), keys, executor, calculator);
		}

		return batchDeleteOnStandalone(this.options.getDeleteBatchSize(), keys, executor);
	}

	/**
	 * Fetches the values of the given keys in batches of
	 * {@link BatchOptions#getFetchBatchSize()}.
	 */
	@SuppressWarnings("unchecked")
	@Override
	public List<byte[]> mget(byte[]... keys) {
		if (redissonDelegate != null) {
			return (List<byte[]>) invokeMethod(redissonDelegate, "mget", byte[][].class, keys);
		}
		if (isClusterConnection()) {
			return mgetOnCluster(keys);
		}
		return batchGetOnStandalone(options.getFetchBatchSize(), keys, (batchKeys) -> {
			return getNativeConnection().mGet(batchKeys);
		});
	}

	/**
	 * Cluster variant of {@link #mget(byte[]...)}: keys are grouped by hash slot.
	 * On a Jedis cluster, {@code JedisCluster#mget} is invoked reflectively.
	 */
	@SuppressWarnings("unchecked")
	protected List<byte[]> mgetOnCluster(byte[]... keys) {
		boolean isJedis = isJedisClusterConnection();
		return batchGetOnCluster(this.options.getFetchBatchSize(), keys, (batchKeys) -> {
			if (isJedis) {
				Object primitive = getNativeConnection().getNativeConnection();
				return (List<byte[]>) invokeMethod(primitive, "mget", new Class[] { byte[][].class },
						((Object) batchKeys));
			}
			return getNativeConnection().mGet(batchKeys);
		}, calculator);
	}

	/**
	 * Deletes {@code key}. Without a Redisson delegate, this evaluates the
	 * {@code GETDEL} script constant with the key.
	 *
	 * @return the script (or delegate) result
	 */
	@Override
	public byte[] del(byte[] key) {
		if (redissonDelegate != null) {
			return (byte[]) invokeMethod(redissonDelegate, "del", byte[].class, key);
		}
		byte[] command = StringSerializer.UTF_8.serialize(GETDEL);
		return this.nativeConnection.eval(command, ReturnType.VALUE, 1, key);
	}

	/**
	 * Collects all keys matching {@code pattern} using {@code SCAN}. On a
	 * cluster, each master node is scanned.
	 */
	@SuppressWarnings("unchecked")
	@Override
	public Set<byte[]> keys(byte[] pattern) {
		if (redissonDelegate != null) {
			return (Set<byte[]>) invokeMethod(redissonDelegate, "keys", byte[].class, pattern);
		}
		if (isClusterConnection()) {
			return getClusterKeys(pattern);
		}
		return getKeys(pattern);
	}

	/**
	 * @return {@code true} if the wrapped connection is a Spring
	 *         {@link RedisClusterConnection}
	 */
	@Override
	public boolean isClusterConnection() {
		return getNativeConnection() instanceof RedisClusterConnection;
	}

	/**
	 * @return {@code true} if the wrapped connection is a Spring
	 *         {@link JedisClusterConnection}
	 */
	protected boolean isJedisClusterConnection() {
		return getNativeConnection() instanceof JedisClusterConnection;
	}

	/**
	 * @return {@code true} if Redisson's Spring connection class is available and
	 *         the wrapped connection is an instance of it
	 */
	protected boolean isRedissonConnection() {
		return _redissonConnectionClass != null
				&& _redissonConnectionClass.isAssignableFrom(getNativeConnection().getClass());
	}

	/**
	 * Closes the wrapped Spring connection.
	 */
	@Override
	public void close() {
		this.getNativeConnection().close();
	}

	/**
	 * @return the wrapped Spring Data Redis connection
	 */
	public org.springframework.data.redis.connection.RedisConnection getNativeConnection() {
		return nativeConnection;
	}

	/**
	 * @return the Redisson delegate, or {@code null} if not using Redisson
	 */
	protected AbstractRedissonConnection getRedissonDelegate() {
		return redissonDelegate;
	}

	/**
	 * Scans a standalone server for keys matching {@code pattern}.
	 * <p>
	 * The cursor is closed when iteration finishes, as Spring Data Redis requires,
	 * so the underlying resources are released.
	 */
	protected Set<byte[]> getKeys(byte[] pattern) {
		Set<byte[]> keys = new HashSet<byte[]>();
		ScanOptions scanOptions = ScanOptions.scanOptions().match(StringSerializer.UTF_8.deserialize(pattern))
				.count(this.options.getScanBatchSize()).build();
		try (Cursor<byte[]> cursor = getNativeConnection().scan(scanOptions)) {
			while (cursor.hasNext()) {
				keys.add(cursor.next());
			}
		}
		return keys;
	}

	/**
	 * Submits one scan task per master node (as returned by
	 * {@code clusterGetMasterSlaveMap()}) and lets {@code distributionScanKeys}
	 * collect the results.
	 * <p>
	 * NOTE(review): the returned count appears to tell
	 * {@code distributionScanKeys} how many task results to await; confirm in
	 * {@link AbstractRedisConnection}.
	 */
	protected Set<byte[]> getClusterKeys(byte[] pattern) {
		return distributionScanKeys((completion) -> {
			Map<RedisClusterNode, Collection<RedisClusterNode>> masterSlaveMap = ((RedisClusterConnection) getNativeConnection())
					.clusterGetMasterSlaveMap();
			Set<RedisClusterNode> masters = masterSlaveMap.keySet();
			final String patternAsText = StringSerializer.UTF_8.deserialize(pattern);
			masters.forEach(node -> {
				completion.submit(() -> {
					return getKeysFromNode(node, patternAsText);
				});
			});

			return masters.size();
		});
	}

	/**
	 * Scans a single cluster node for keys matching {@code pattern}. The cursor is
	 * always closed after use.
	 */
	private Set<byte[]> getKeysFromNode(RedisClusterNode node, String pattern) {
		Set<byte[]> keys = new HashSet<byte[]>();
		ScanOptions scanOptions = ScanOptions.scanOptions().match(pattern).count(this.options.getScanBatchSize())
				.build();
		try (Cursor<byte[]> cursor = ((RedisClusterConnection) getNativeConnection()).scan(node, scanOptions)) {
			while (cursor.hasNext()) {
				keys.add(cursor.next());
			}
		}
		return keys;
	}

	/**
	 * Extracts the {@code redis_version} value from an {@code INFO server} reply.
	 * <p>
	 * For a standalone connection the key is exactly {@code redis_version}. For a
	 * cluster, the first key containing {@code redis_version} is used, since
	 * cluster keys may carry a node prefix.
	 *
	 * @param cluster    whether the reply came from a cluster connection
	 * @param properties the INFO reply; may be {@code null}
	 * @return the version text, or {@code EMPTY_STRING} if it cannot be found
	 */
	private static String getServerVersion(boolean cluster, final Properties properties) {
		if (properties == null) {
			return EMPTY_STRING;
		}
		if (!cluster) {
			Object version = properties.get("redis_version");
			return version != null ? version.toString() : EMPTY_STRING;
		}
		for (Entry<Object, Object> entry : properties.entrySet()) {
			if (entry.getKey().toString().contains("redis_version")) {
				return entry.getValue().toString();
			}
		}
		return EMPTY_STRING;
	}

}
