/**
 * Copyright (c) 2016-2019, Bosco.Liao (bosco_liao@126.com).
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not
 * use this file except in compliance with the License. You may obtain a copy of
 * the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */
package org.iherus.shiro.cache.redis;

import java.io.Serializable;
import java.time.Duration;
import java.util.Collection;
import java.util.Date;
import java.util.HashMap;

import org.apache.shiro.ShiroException;
import org.apache.shiro.cache.Cache;
import org.apache.shiro.cache.CacheException;
import org.apache.shiro.session.Session;
import org.apache.shiro.session.UnknownSessionException;
import org.apache.shiro.session.mgt.SimpleSession;
import org.apache.shiro.session.mgt.ValidatingSession;
import org.apache.shiro.session.mgt.eis.CachingSessionDAO;
import org.apache.shiro.session.mgt.eis.SessionIdGenerator;
import org.apache.shiro.util.Initializable;
import org.apache.shiro.util.ThreadContext;
import org.iherus.shiro.cache.redis.serializer.DefaultSerializer;
import org.iherus.shiro.util.Md5Utils;
import org.iherus.shiro.util.Utils;

/**
 * Implementation of {@link org.apache.shiro.session.mgt.eis.SessionDAO} that stores
 * sessions in the configured active-sessions cache (e.g. a {@literal Redis} cache)
 * and keeps a short-lived in-memory copy in front of it to reduce {@literal Redis}
 * access pressure.
 * 
 * @author Bosco.Liao
 * @since 2.0.0
 */
public class RedisSessionDAO extends CachingSessionDAO implements Initializable {

	private static final String CLASS_NAME = RedisSessionDAO.class.getName();
	private static final String MD5_OF_SESSION_ATTR = CLASS_NAME + "_MD5";
	private static final String REFERENCE_TIME_ATTR = CLASS_NAME + "_referenceTime";

	/**
	 * Per-thread flag handed from {@link #doUpdate(Session)} to
	 * {@link #update(Session)}: {@code FALSE} means the cache write for the current
	 * update is skipped. Always cleared at the end of {@link #update(Session)}.
	 */
	private static final ThreadLocal<Boolean> NEXTS = new ThreadLocal<Boolean>();

	private final MemorySessionCache localCache;
	
	/**
	 * Session ID prefix.
	 * 
	 * @see SessionIdGenerator
	 */
	private String sessionIdPrefix;

	/**
	 * In the absence of substantial changes, this time will be used as a benchmark
	 * to delay updating the session to {@literal Redis}, which may cause the
	 * session to expire early. default: -1 (No Delay)
	 */
	private long updateAboveDelayTimeMillis = -1;

	public RedisSessionDAO() {
		this.localCache = new MemorySessionCache();
	}

	/**
	 * Sets the maximum capacity of the in-memory session cache.
	 * 
	 * <p>Default is equals to {@literal 2 << 16 = 131072 }.
	 * 
	 * @param capacity Maximum local cache capacity.
	 * @return this DAO, for chaining.
	 * @deprecated please use {@link #setMemoryCacheMaxCapacity(long)} instead.
	 */
	@Deprecated
	public RedisSessionDAO setLocalSessionCacheMaxCapacity(long capacity) {
		return setMemoryCacheMaxCapacity(capacity);
	}
	
	/**
	 * Sets the maximum capacity of the in-memory session cache.
	 * 
	 * <p>Default is equals to {@literal 2 << 16 = 131072 }.
	 * 
	 * @param capacity Maximum local cache capacity.
	 * @return this DAO, for chaining.
	 */
	public RedisSessionDAO setMemoryCacheMaxCapacity(long capacity) {
		this.localCache.setCapacityThreshold(capacity);
		return this;
	}

	/**
	 * Sets the lifetime of sessions held in the in-memory cache.
	 * 
	 * <p>Default: {@literal 1000 millisecond}, it is recommended to use the default value.
	 * 
	 * @param ttl The lifetime of the locally cached session.
	 * @return this DAO, for chaining.
	 */
	public RedisSessionDAO setMemorySessionTtl(int ttl) {
		this.localCache.setTtl(ttl);
		return this;
	}
	
	/**
	 * @return the update delay narrowed to {@code int}.
	 * @deprecated please use {@link #getUpdateAboveDelayTimeMillis()} instead.
	 */
	@Deprecated
	public int getDelayTimeMillis() {
		return (int) getUpdateAboveDelayTimeMillis();
	}

	/**
	 * Default is equals {@literal -1} (No Delay).
	 * 
	 * @param delayTimeMillis The delay time of the update operation.
	 * @deprecated please use {@link #setUpdateAboveDelayTimeMillis(long)} instead.
	 */
	@Deprecated
	public void setDelayTimeMillis(int delayTimeMillis) {
		setUpdateAboveDelayTimeMillis(delayTimeMillis);
	}
	
	/**
	 * @return the delay (in milliseconds) within which unchanged sessions are not
	 *         written back; a value {@code <= 0} disables the delay.
	 */
	public long getUpdateAboveDelayTimeMillis() {
		return updateAboveDelayTimeMillis;
	}

	/**
	 * Default is equals {@literal -1} (No Delay). Any value {@code <= 0} disables
	 * delayed synchronization.
	 * 
	 * @param updateAboveDelayTimeMillis The delay time of the update operation.
	 */
	public void setUpdateAboveDelayTimeMillis(long updateAboveDelayTimeMillis) {
		this.updateAboveDelayTimeMillis = updateAboveDelayTimeMillis;
	}

	private boolean isDelaySyncEnabled() {
		return this.updateAboveDelayTimeMillis > 0;
	}
	
	public String getSessionIdPrefix() {
		return sessionIdPrefix;
	}

	/**
	 * Sets sessionId prefix, must be the same as the sessionId prefix generated by
	 * the generator.
	 * 
	 * @param sessionIdPrefix prefix
	 * @see org.apache.shiro.session.mgt.eis.SessionIdGenerator
	 */
	public void setSessionIdPrefix(String sessionIdPrefix) {
		this.sessionIdPrefix = sessionIdPrefix;
	}

	/**
	 * @param cache cache to inspect.
	 * @return {@code true} if the cache supports per-entry expiration.
	 */
	@SuppressWarnings("rawtypes")
	public static boolean isExpiredCacheAware(Cache cache) {
		return cache instanceof ExpiredCache;
	}

	/**
	 * @param session session to inspect.
	 * @return {@code true} if the session is a {@link SimpleSession}.
	 */
	public static boolean isSimpleSession(Session session) {
		return session instanceof SimpleSession;
	}

	/**
	 * Initializes the in-memory session cache.
	 * 
	 * @throws CacheException if the in-memory cache fails to initialize.
	 */
	@Override
	public void init() throws ShiroException {
		try {
			localCache.init();
		} catch (Exception e) {
			throw new CacheException("MemorySessionCache cannot be initialized normally", e);
		}
	}

	/**
	 * Writes the session to the active-sessions cache and then to the in-memory
	 * cache. When the cache supports per-entry expiration, the session's own
	 * timeout is used as the entry lifetime.
	 */
	@SuppressWarnings({ "unchecked", "rawtypes" })
	@Override
	protected void cache(Session session, Serializable sessionId, Cache<Serializable, Session> cache) {
		if (isExpiredCacheAware(cache)) {
			// NOTE(review): a negative Shiro timeout means "never expires"; how
			// ExpiredCache treats a negative Duration is not visible here — confirm.
			((ExpiredCache) cache).put(sessionId, session, Duration.ofMillis(session.getTimeout()));
		} else {
			super.cache(session, sessionId, cache);
		}
		afterCache(sessionId, session);
	}

	/**
	 * Mirrors a freshly cached session into the in-memory cache.
	 * 
	 * @param sessionId session id.
	 * @param session   session that was just cached.
	 */
	protected void afterCache(Serializable sessionId, Session session) {
		try {
			this.localCache.put(sessionId, session);
		} catch (Exception ignored) {
			// Best effort: the active-sessions cache already holds the session, so a
			// failure of the local copy only costs an extra remote read later.
		}
	}

	/**
	 * Removes the session from the active-sessions cache, except when that cache
	 * expires entries by itself ({@link ExpiredCache}), in which case nothing is
	 * removed here. {@link #delete(Session)} uses the parent implementation for an
	 * unconditional removal.
	 */
	@Override
	protected void uncache(Session session) {
		if (session == null) {
			return;
		}
		Serializable id = session.getId();
		if (id == null) {
			return;
		}
		Cache<Serializable, Session> cache = this.getActiveSessionsCacheLazy();
		if (cache != null && !isExpiredCacheAware(cache)) {
			cache.remove(id);
		}
	}

	/**
	 * Reflectively calls the parent's lazy accessor for the active-sessions cache.
	 * 
	 * @return the active-sessions cache, possibly {@code null}.
	 */
	@SuppressWarnings("unchecked")
	protected Cache<Serializable, Session> getActiveSessionsCacheLazy() {
		return (Cache<Serializable, Session>) Utils.invokeMethod(this, "getActiveSessionsCacheLazy");
	}

	/**
	 * Generates and assigns a session id, then seeds the delayed-sync reference
	 * attributes.
	 */
	@Override
	protected Serializable doCreate(Session session) {
		Serializable sessionId = generateSessionId(session);
		assignSessionId(session, sessionId);
		initializeReferences(session);
		return sessionId;
	}

	/**
	 * When delayed synchronization is enabled, stores an empty MD5 and the current
	 * last-access time on the session as the baseline for {@link #doUpdate(Session)}.
	 * 
	 * @param session newly created session.
	 */
	protected void initializeReferences(Session session) {
		if (isDelaySyncEnabled() && isSimpleSession(session)) {
			session.setAttribute(MD5_OF_SESSION_ATTR, "");
			session.setAttribute(REFERENCE_TIME_ATTR, session.getLastAccessTime());
		}
	}

	/**
	 * Runs {@link #doUpdate(Session)} and, unless it vetoed the write, caches a
	 * valid session or uncaches an invalid one.
	 */
	@Override
	public void update(Session session) throws UnknownSessionException {
		try {
			doUpdate(session);
			// Only an explicit veto skips the write. A null flag (e.g. a subclass
			// overriding doUpdate without calling super) falls back to the parent's
			// default of always caching instead of failing on auto-unboxing.
			if (Boolean.FALSE.equals(NEXTS.get())) {
				return;
			}
			if (session instanceof ValidatingSession) {
				if (((ValidatingSession) session).isValid()) {
					cache(session, session.getId());
				} else {
					uncache(session);
				}
			} else {
				cache(session, session.getId());
			}
		} finally {
			NEXTS.remove();
		}
	}

	/**
	 * Decides whether the pending update should be written. With delayed
	 * synchronization enabled, a {@link SimpleSession} whose content (MD5) is
	 * unchanged and whose last-access time is still within
	 * {@link #getUpdateAboveDelayTimeMillis()} of the reference time is not written.
	 * Otherwise the reference attributes are refreshed and the write proceeds.
	 */
	@Override
	protected void doUpdate(Session session) {

		if (isDelaySyncEnabled() && isSimpleSession(session)) {
			String md5 = getMd5((SimpleSession) session);
			Date lastAccessTime = session.getLastAccessTime();
			Object reference = session.getAttribute(REFERENCE_TIME_ATTR);
			Date datumTime = (reference instanceof Date) ? (Date) reference : null;

			// A missing or malformed reference time (e.g. a session created before
			// delayed sync was enabled) forces a write instead of throwing. The full
			// long delay is used; the deprecated int getter would truncate it.
			if (md5.equals(session.getAttribute(MD5_OF_SESSION_ATTR))
					&& lastAccessTime != null && datumTime != null
					&& (lastAccessTime.getTime() - datumTime.getTime() < this.updateAboveDelayTimeMillis)) {

				// Don't execute update.
				NEXTS.set(false);
				return;
			}

			// Update references.
			session.setAttribute(MD5_OF_SESSION_ATTR, md5);
			session.setAttribute(REFERENCE_TIME_ATTR, lastAccessTime);
		}

		NEXTS.set(true);
	}

	/**
	 * Computes an MD5 over the session's state. The last-access time and this
	 * DAO's own reference attributes are excluded, so that merely touching the
	 * session does not change the hash.
	 * 
	 * @param session session to fingerprint.
	 * @return MD5 digest of the serialized snapshot.
	 */
	protected String getMd5(SimpleSession session) {
		SimpleSession sessionCopy = new SimpleSession();
		sessionCopy.setId(session.getId());
		sessionCopy.setStartTimestamp(session.getStartTimestamp());
		sessionCopy.setStopTimestamp(session.getStopTimestamp());
		sessionCopy.setTimeout(session.getTimeout());
		sessionCopy.setExpired(session.isExpired());
		sessionCopy.setHost(session.getHost());
		// A SimpleSession may not have an attribute map yet; copying null would NPE.
		if (session.getAttributes() != null) {
			sessionCopy.setAttributes(new HashMap<Object, Object>(session.getAttributes()));
		}
		/**
		 * Ignore the following values.
		 */
		sessionCopy.setLastAccessTime(null);
		sessionCopy.removeAttribute(MD5_OF_SESSION_ATTR);
		sessionCopy.removeAttribute(REFERENCE_TIME_ATTR);

		byte[] bytes = DefaultSerializer.INSTANCE.serialize(sessionCopy);

		return Md5Utils.getMd5(bytes);
	}

	/**
	 * Unconditionally removes the session from the active-sessions cache (via the
	 * parent implementation, bypassing this class's ExpiredCache exemption) and
	 * then from the in-memory cache.
	 */
	@Override
	public void delete(Session session) {
		super.uncache(session);
		doDelete(session);
	}

	/**
	 * Removes the session from the in-memory cache.
	 */
	@Override
	protected void doDelete(Session session) {
		this.localCache.remove(session.getId());
	}

	/**
	 * Looks the session up in the in-memory cache first and falls back to the
	 * active-sessions cache, populating the in-memory cache on a remote hit.
	 */
	@Override
	protected Session getCachedSession(Serializable sessionId) {
		Session session = localCache.get(sessionId);
		if (session == null) {
			session = super.getCachedSession(sessionId);
			if (session != null) {
				localCache.put(sessionId, session);
			}
		}
		return session;
	}

	@Override
	protected Session doReadSession(Serializable sessionId) {
		return null; //should never execute because this implementation relies on parent class to access cache, which
		//is where all sessions reside - it is the cache implementation that determines if the
		//cache is memory only or disk-persistent, etc.
	}

	/**
	 * Returns the active sessions. When a session id prefix is configured, it is
	 * published to {@link ThreadContext} for the duration of the lookup only. The
	 * previous thread-bound value is restored afterwards, so the prefix cannot leak
	 * into unrelated cache calls on a pooled thread.
	 */
	@Override
	public Collection<Session> getActiveSessions() {
		if (Utils.isBlank(getSessionIdPrefix())) {
			return super.getActiveSessions();
		}
		Object previousPrefix = ThreadContext.get(RedisCache.KEY_PREFIX_THREAD);
		ThreadContext.put(RedisCache.KEY_PREFIX_THREAD, getSessionIdPrefix());
		try {
			return super.getActiveSessions();
		} finally {
			if (previousPrefix == null) {
				ThreadContext.remove(RedisCache.KEY_PREFIX_THREAD);
			} else {
				ThreadContext.put(RedisCache.KEY_PREFIX_THREAD, previousPrefix);
			}
		}
	}

}
