/**
 * Copyright (c) 2016-2019, Bosco.Liao (bosco_liao@126.com).
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not
 * use this file except in compliance with the License. You may obtain a copy of
 * the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */
package org.iherus.shiro.cache.redis;

import java.io.Serializable;
import java.lang.ref.SoftReference;
import java.time.LocalDateTime;
import java.time.temporal.ChronoUnit;
import java.util.Iterator;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;

import org.apache.shiro.session.Session;
import org.iherus.shiro.util.concurrent.ConcurrentLinkedHashMap.Builder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Short-lived, in-process cache of {@link Session} instances keyed by session id.
 * <p>
 * Each entry is held through a {@link SoftReference}, so the garbage collector may reclaim it
 * under memory pressure. Each entry also expires {@link #getTtl()} milliseconds after it was put.
 * Stale entries (expired, or whose referent was reclaimed) are evicted in two ways:
 * <ul>
 * <li>lazily by {@link #get(Serializable)} when no cleanup task is pending;</li>
 * <li>by a one-shot background cleanup task that {@link #put(Serializable, Session)} schedules.</li>
 * </ul>
 * <p>
 * {@link #init()} must be called before any other cache operation; until then the backing map is
 * {@code null}.
 */
class MemorySessionCache {

	private static final Logger logger = LoggerFactory.getLogger(MemorySessionCache.class);

	private static final int DEFAULT_INITIAL_CAPACITY = 1000;
	// 2 << 16 == 131072 entries.
	private static final long DEFAULT_MAX_CAPACITY = 2 << 16;

	/**
	 * Delay before a scheduled cleanup task runs, in milliseconds. The delay time should not be too
	 * long, otherwise it will cause multiple nodes to consume memory for the same session.
	 */
	private static final long CLEANUP_DELAY_MILLIS = 3000;

	private int ttl = 1000; // millisecond
	private long capacityThreshold = DEFAULT_MAX_CAPACITY;

	private ConcurrentMap<Serializable, SoftReference<LocalSession>> cache;

	/** {@code true} while a cleanup task is scheduled or running and has not finished yet. */
	private final AtomicBoolean watchdogEnabled = new AtomicBoolean(false);

	private volatile ScheduledExecutorService executorService;

	/**
	 * Returns the maximum capacity handed to the backing map builder by {@link #init()}.
	 *
	 * @return the configured capacity threshold
	 */
	public long getCapacityThreshold() {
		return capacityThreshold;
	}

	/**
	 * Sets the maximum capacity used by {@link #init()}. Negative values are not passed to the
	 * builder. The value only takes effect when {@link #init()} is (re)invoked.
	 *
	 * @param capacityThreshold the capacity threshold
	 */
	public void setCapacityThreshold(long capacityThreshold) {
		this.capacityThreshold = capacityThreshold;
	}

	/**
	 * Returns the time-to-live applied to newly put entries.
	 *
	 * @return the TTL in milliseconds
	 */
	public int getTtl() {
		return ttl;
	}

	/**
	 * Sets the time-to-live applied to entries put after this call.
	 *
	 * @param ttl the TTL in milliseconds
	 */
	public void setTtl(int ttl) {
		this.ttl = ttl;
	}

	/**
	 * Creates the backing map. If {@link #getCapacityThreshold()} is non-negative, it is passed to
	 * the builder's {@code maximumWeightedCapacity}.
	 */
	public void init() {
		Builder<Serializable, SoftReference<LocalSession>> builder = new Builder<Serializable, SoftReference<LocalSession>>();
		builder.initialCapacity(DEFAULT_INITIAL_CAPACITY);
		Optional.of(capacityThreshold).filter(c -> c >= 0).ifPresent(builder::maximumWeightedCapacity);
		this.cache = builder.build();
	}

	/**
	 * Lazily creates the single-threaded scheduler used for cleanup tasks.
	 * <p>
	 * The worker is a daemon thread so that a pending cleanup never keeps the JVM alive.
	 */
	private ScheduledExecutorService getExecutorService() {
		ScheduledExecutorService service = executorService;
		if (service == null) {
			synchronized (this) {
				service = executorService;
				if (service == null) {
					service = Executors.newSingleThreadScheduledExecutor(runnable -> {
						Thread thread = new Thread(runnable, "shiro-local-session-cache-cleaner");
						thread.setDaemon(true);
						return thread;
					});
					executorService = service;
				}
			}
		}
		return service;
	}

	/**
	 * Looks up a cached session.
	 *
	 * @param key the session id; {@code null} yields {@code null}
	 * @return the cached session, or {@code null} if the key is absent, the entry has expired, or
	 *         its soft reference has been cleared by the GC
	 */
	public Session get(Serializable key) {

		if (key == null) return null;

		if (logger.isDebugEnabled()) {
			logger.debug("Getting a session instance with key [{}] from the local cache.", key);
		}

		SoftReference<LocalSession> ref = this.cache.get(key);
		if (ref == null) {
			return null;
		}

		// The referent may already have been reclaimed by the GC.
		LocalSession localSession = ref.get();
		if (localSession != null && localSession.isValid()) {
			return localSession.getSession();
		}

		// Stale entry. Evict it now unless a cleanup task is pending. The conditional remove
		// keeps a fresh entry that another thread may have put for the same key meanwhile.
		if (!watchdogEnabled.get()) {
			this.cache.remove(key, ref);
		}

		return null;
	}

	/**
	 * Caches a session, expiring {@link #getTtl()} milliseconds from now. If no cleanup task is
	 * pending, also schedules one to run after {@value #CLEANUP_DELAY_MILLIS} ms.
	 *
	 * @param key     the session id
	 * @param session the session to cache
	 * @return the session previously cached under {@code key} (even if expired), or {@code null} if
	 *         there was none or its soft reference had been cleared
	 */
	public Session put(Serializable key, Session session) {
		try {
			if (logger.isDebugEnabled()) {
				logger.debug("Putting a session instance with key [{}] to the local cache.", key);
			}
			return unwrap(this.cache.put(key, wrap(session, getTtl())));
		} finally {
			scheduleCleanupIfIdle();
		}
	}

	/**
	 * Schedules a one-shot {@link #evictStaleEntries()} run unless one is already pending.
	 */
	private void scheduleCleanupIfIdle() {
		if (!watchdogEnabled.compareAndSet(false, true)) {
			return;
		}
		try {
			getExecutorService().schedule(this::evictStaleEntries, CLEANUP_DELAY_MILLIS, TimeUnit.MILLISECONDS);
		} catch (RuntimeException e) {
			// Nothing was scheduled, so release the flag; otherwise no cleanup would ever be scheduled
			// again and get() would stop evicting stale entries.
			watchdogEnabled.set(false);
			throw e;
		}
	}

	/**
	 * Removes every expired or GC-cleared entry, then clears the watchdog flag.
	 */
	private void evictStaleEntries() {
		if (logger.isDebugEnabled()) {
			logger.debug("Executing a cleanup job for a locally cached session.");
		}

		try {
			for (Entry<Serializable, SoftReference<LocalSession>> entry : cache.entrySet()) {
				SoftReference<LocalSession> ref = entry.getValue();
				LocalSession localSession = ref.get();
				if (localSession == null || !localSession.isValid()) {
					// Conditional remove: leave an entry alone if it was replaced after we read it.
					cache.remove(entry.getKey(), ref);
				}
			}
		} finally {
			watchdogEnabled.compareAndSet(true, false);
		}
	}

	/**
	 * Removes a cached session.
	 *
	 * @param key the session id
	 * @return the removed session (even if expired), or {@code null} if there was none or its soft
	 *         reference had been cleared
	 */
	public Session remove(Serializable key) {
		return unwrap(this.cache.remove(key));
	}

	/** Removes all cached sessions. */
	public void clear() {
		this.cache.clear();
	}

	/**
	 * Returns the number of map entries. This includes entries that are expired or GC-cleared but
	 * not yet evicted.
	 *
	 * @return the entry count
	 */
	public int size() {
		return this.cache.size();
	}

	/**
	 * Wraps a session with an expiry deadline {@code ttl} milliseconds from now.
	 */
	private static SoftReference<LocalSession> wrap(final Session session, final long ttl) {
		long expiresAtNanos = System.nanoTime() + TimeUnit.MILLISECONDS.toNanos(ttl);
		return new SoftReference<LocalSession>(new LocalSession(session, expiresAtNanos));
	}

	/**
	 * Extracts the session from a possibly-null, possibly-cleared reference.
	 */
	private static Session unwrap(SoftReference<LocalSession> ref) {
		LocalSession localSession = ref == null ? null : ref.get();
		return localSession == null ? null : localSession.getSession();
	}

	/**
	 * A cached session together with its expiry deadline.
	 * <p>
	 * The deadline is a {@link System#nanoTime()} value. Unlike {@code LocalDateTime.now()}, nanoTime
	 * is unaffected by wall-clock changes (DST transitions, NTP or manual adjustments), which could
	 * otherwise stretch or shrink entry lifetimes.
	 */
	private static final class LocalSession {

		final Session session;
		final long expiresAtNanos;

		LocalSession(Session session, long expiresAtNanos) {
			this.session = session;
			this.expiresAtNanos = expiresAtNanos;
		}

		Session getSession() {
			return session;
		}

		boolean isValid() {
			// Compare via subtraction, as required for System.nanoTime() values (overflow-safe).
			return System.nanoTime() - expiresAtNanos < 0;
		}

	}

}
