/**
 * Copyright (c) 2016-2021, Bosco.Liao (bosco_liao@126.com).
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not
 * use this file except in compliance with the License. You may obtain a copy of
 * the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */
package org.iherus.shiro.cache.redis;

import org.apache.shiro.cache.Cache;
import org.apache.shiro.cache.CacheException;
import org.apache.shiro.session.Session;
import org.iherus.shiro.util.concurrent.ConcurrentReferenceHashMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.Serializable;
import java.time.Duration;
import java.time.LocalDateTime;
import java.util.*;

/**
 * A short-lived, in-memory {@link Cache} of {@link Session}s.
 *
 * <p>Every entry is stamped with a deadline of "time of {@code put} + keep-alive time".
 * Once that deadline has passed, {@link #get(Serializable)} reports the entry as absent
 * and evicts it from the backing map.
 *
 * <p>Deadlines are measured with {@link System#nanoTime()} rather than the local
 * wall-clock time, so daylight-saving transitions and system clock adjustments neither
 * prolong nor cut short the life of an entry.
 *
 * @author Bosco.Liao
 * @since 2.0.0
 */
public class MemorySessionCache implements Cache<Serializable, Session> {

    private static final Logger logger = LoggerFactory.getLogger(MemorySessionCache.class);

    private static final Duration DEFAULT_ALIVE_TIME = Duration.ofMillis(1000);

    /**
     * How long a session stays valid in memory after being put; defaults to
     * {@literal 1 second}. Declared {@code volatile} so that a value assigned through
     * {@link #setKeepAliveTime(Duration)} is visible to threads calling
     * {@link #put(Serializable, Session)}.
     */
    private volatile Duration keepAliveTime;

    private final Map<Serializable, ValidationSession> cacheDelegate;

    /**
     * Creates an empty cache using the default keep-alive time of one second.
     */
    public MemorySessionCache() {
        this.keepAliveTime = DEFAULT_ALIVE_TIME;
        this.cacheDelegate = new ConcurrentReferenceHashMap<>();
    }

    /**
     * Returns the keep-alive time applied to entries stored from now on.
     *
     * @return the current keep-alive time, never {@code null}
     */
    public Duration getKeepAliveTime() {
        return keepAliveTime;
    }

    /**
     * Sets the keep-alive time applied to entries stored from now on. Entries that are
     * already cached keep the deadline they were given when they were put.
     *
     * @param keepAliveTime the new keep-alive time; {@code null} restores the default
     *                      of one second, and a zero or negative value makes every
     *                      newly stored entry expire immediately
     */
    public void setKeepAliveTime(Duration keepAliveTime) {
        this.keepAliveTime = (keepAliveTime != null ? keepAliveTime : DEFAULT_ALIVE_TIME);
    }

    /**
     * Looks up the session stored under {@code key}.
     *
     * @param key the cache key; {@code null} always yields {@code null}
     * @return the cached session, or {@code null} if there is no entry for the key or
     *         the entry has outlived its keep-alive time
     */
    @Override
    public Session get(Serializable key) throws CacheException {
        if (key == null) {
            return null;
        }

        if (logger.isDebugEnabled()) {
            logger.debug("Gets a session with key [{}] from the memory cache.", key);
        }

        ValidationSession current = cacheDelegate.getOrDefault(key, ValidationSession.NULL_SESSION);

        if (current == ValidationSession.NULL_SESSION) {
            if (logger.isDebugEnabled()) {
                logger.debug("The session with key [{}] cannot be found in the memory cache.", key);
            }
            return null;
        }

        Session session = current.getSession();
        if (session == null) {
            // The entry is expired (or holds no session): evict it so it stops occupying
            // the map. The two-argument remove only deletes the mapping if it still
            // points at this very entry, so a fresh entry stored concurrently survives.
            cacheDelegate.remove(key, current);
        }
        return session;
    }

    /**
     * Stores {@code session} under {@code key}; it stays valid for the keep-alive time
     * that is configured at the moment of this call.
     *
     * @param key     the cache key
     * @param session the session to cache
     * @return the session previously stored under the key if it was still valid,
     *         otherwise {@code null}
     */
    @Override
    public Session put(Serializable key, Session session) throws CacheException {
        if (logger.isDebugEnabled()) {
            logger.debug("Puts a session with key [{}] to the memory cache.", key);
        }

        ValidationSession prev = cacheDelegate.put(key, ValidationSession.expiringAfter(session, keepAliveTime));
        return prev == null ? null : prev.getSession();
    }

    /**
     * Removes the entry stored under {@code key}.
     *
     * @param key the cache key
     * @return the removed session if it was still valid, otherwise {@code null}
     */
    @Override
    public Session remove(Serializable key) throws CacheException {
        if (logger.isDebugEnabled()) {
            logger.debug("Removes a session with key [{}] from the memory cache.", key);
        }
        ValidationSession prev = cacheDelegate.remove(key);
        return prev == null ? null : prev.getSession();
    }

    /**
     * Removes every entry from the cache.
     */
    @Override
    public void clear() throws CacheException {
        if (logger.isDebugEnabled()) {
            logger.debug("Clears all sessions from the memory cache.");
        }

        cacheDelegate.clear();
    }

    /**
     * Returns the number of entries held by the backing map. Expired entries are only
     * evicted when they are looked up, so this count may include entries that
     * {@link #get(Serializable)} would no longer return.
     *
     * @return the size of the backing map
     */
    @Override
    public int size() {
        return cacheDelegate.size();
    }

    /**
     * Key enumeration is intentionally not implemented.
     *
     * @return always an empty set, regardless of the cache contents
     */
    @Override
    public Set<Serializable> keys() {
        // Discard meaningless implementations.
        return Collections.emptySet();
    }

    /**
     * Value enumeration is intentionally not implemented.
     *
     * @return always an empty collection, regardless of the cache contents
     */
    @Override
    public Collection<Session> values() {
        // Discard meaningless implementations.
        return Collections.emptySet();
    }

    /**
     * Immutable pairing of a session with the {@link System#nanoTime()} reading after
     * which it must no longer be handed out.
     */
    private static final class ValidationSession {

        /** Sentinel meaning "no entry"; it holds no session and is therefore never valid. */
        static final ValidationSession NULL_SESSION = new ValidationSession(null, 0L);

        /** Largest duration that {@link Duration#toNanos()} can convert without overflowing. */
        private static final Duration MAX_NANOS = Duration.ofNanos(Long.MAX_VALUE);

        private final Session session;
        private final long deadlineNanos;

        private ValidationSession(Session session, long deadlineNanos) {
            this.session = session;
            this.deadlineNanos = deadlineNanos;
        }

        /**
         * Creates an entry that becomes invalid once {@code timeToLive} has elapsed.
         *
         * @param session    the session to hold
         * @param timeToLive how long the entry stays valid; zero or negative values
         *                   yield an entry that is already expired
         */
        static ValidationSession expiringAfter(Session session, Duration timeToLive) {
            // The sum may wrap around; that is fine because isValid() only ever looks
            // at the difference between the deadline and the current reading.
            return new ValidationSession(session, System.nanoTime() + saturatedNanos(timeToLive));
        }

        /**
         * Converts a duration to nanoseconds, clamped to the range [0, Long.MAX_VALUE]
         * so that the wrap-around comparison in {@link #isValid()} stays correct.
         */
        private static long saturatedNanos(Duration duration) {
            if (duration == null || duration.isNegative()) {
                return 0L;
            }
            return duration.compareTo(MAX_NANOS) >= 0 ? Long.MAX_VALUE : duration.toNanos();
        }

        /**
         * @return the held session while the entry is valid, otherwise {@code null}
         */
        Session getSession() {
            return (isValid() ? session : null);
        }

        /**
         * @return {@code true} if a session is held and the deadline has not been reached
         */
        boolean isValid() {
            // Compare via subtraction, never with '<': nanoTime() values may overflow.
            return session != null && deadlineNanos - System.nanoTime() > 0;
        }

    }

}
