package com.davfx.ninio.snmp;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.security.SecureRandom;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Random;
import java.util.concurrent.Executor;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.davfx.ninio.core.Address;
import com.davfx.ninio.core.Connecter;
import com.davfx.ninio.core.Connection;
import com.davfx.ninio.core.NinioBuilder;
import com.davfx.ninio.core.NinioProvider;
import com.davfx.ninio.core.SendCallback;
import com.davfx.ninio.core.UdpSocket;
import com.davfx.ninio.snmp.dependencies.Dependencies;
import com.davfx.ninio.util.ConfigUtils;
import com.davfx.ninio.util.MemoryCache;
import com.google.common.collect.ImmutableList;
import com.typesafe.config.Config;

public final class SnmpClient implements SnmpConnecter {
	private static final Logger LOGGER = LoggerFactory.getLogger(SnmpClient.class);

	private static final Config CONFIG = ConfigUtils.load(new Dependencies()).getConfig(SnmpClient.class.getPackage().getName());

	public static final int DEFAULT_PORT = 161;
	public static final int DEFAULT_TRAP_PORT = 162;

	private static final int BULK_SIZE = CONFIG.getInt("bulkSize");
	private static final double AUTH_ENGINES_CACHE_DURATION = ConfigUtils.getDuration(CONFIG, "auth.cache");

	public static interface Builder extends NinioBuilder<SnmpConnecter> {
		@Deprecated
		Builder with(Executor executor);

		Builder with(NinioBuilder<Connecter> connecterFactory);
	}
	
	public static Builder builder() {
		return new Builder() {
			private NinioBuilder<Connecter> connecterFactory = UdpSocket.builder();
			
			@Deprecated
			@Override
			public Builder with(Executor executor) {
				return this;
			}
			
			@Override
			public Builder with(NinioBuilder<Connecter> connecterFactory) {
				this.connecterFactory = connecterFactory;
				return this;
			}

			@Override
			public SnmpConnecter create(NinioProvider ninioProvider) {
				return new SnmpClient(ninioProvider.executor(), connecterFactory.create(ninioProvider));
			}
		};
	}
	
	private static final class EncryptionEngineKey {
		public final String authDigestAlgorithm;
		public final String privEncryptionAlgorithm;
		
		public EncryptionEngineKey(String authDigestAlgorithm, String privEncryptionAlgorithm) {
			this.authDigestAlgorithm = authDigestAlgorithm;
			this.privEncryptionAlgorithm = privEncryptionAlgorithm;
		}

		@Override
		public int hashCode() {
			return Objects.hash(authDigestAlgorithm, privEncryptionAlgorithm);
		}

		@Override
		public boolean equals(Object obj) {
			if (this == obj) {
				return true;
			}
			if (obj == null) {
				return false;
			}
			if (!(obj instanceof EncryptionEngineKey)) {
				return false;
			}
			EncryptionEngineKey other = (EncryptionEngineKey) obj;
			return Objects.equals(authDigestAlgorithm, other.authDigestAlgorithm)
				&& Objects.equals(privEncryptionAlgorithm, other.privEncryptionAlgorithm);
		}
	}

	private static final class AuthRemoteEngineKey {
		public final Address address;
		public final Auth auth;
		
		public AuthRemoteEngineKey(Address address, Auth auth) {
			this.address = address;
			this.auth = auth;
		}

		@Override
		public int hashCode() {
			return Objects.hash(address, auth);
		}

		@Override
		public boolean equals(Object obj) {
			if (this == obj) {
				return true;
			}
			if (obj == null) {
				return false;
			}
			if (!(obj instanceof AuthRemoteEngineKey)) {
				return false;
			}
			AuthRemoteEngineKey other = (AuthRemoteEngineKey) obj;
			return Objects.equals(address, other.address)
				&& Objects.equals(auth, other.auth);
		}
	}

	private final Executor executor;
	private final Connecter connecter;
	
	private final InstanceMapper instanceMapper;

	private final RequestIdProvider requestIdProvider = new RequestIdProvider();
	private final MemoryCache<Address, Auth> auths = MemoryCache.<Address, Auth> builder().expireAfterAccess(AUTH_ENGINES_CACHE_DURATION).build();
	private final MemoryCache<AuthRemoteEngineKey, AuthRemoteEnginePendingRequestManager> authRemoteEngines = MemoryCache.<AuthRemoteEngineKey, AuthRemoteEnginePendingRequestManager> builder().expireAfterAccess(AUTH_ENGINES_CACHE_DURATION).build();
	private final MemoryCache<EncryptionEngineKey, EncryptionEngine> encryptionEngines = MemoryCache.<EncryptionEngineKey, EncryptionEngine> builder().expireAfterAccess(AUTH_ENGINES_CACHE_DURATION).build();

	private SnmpClient(Executor executor, Connecter connecter) {
		this.executor = executor;
		this.connecter = connecter;
		instanceMapper = new InstanceMapper(requestIdProvider);
	}
	
	@Override
	public SnmpRequestBuilder request() {
		return new SnmpRequestBuilder() {
			private String community = null;
			private AuthRemoteSpecification authRemoteSpecification = null;
			private Address address;
			private Oid oid;
			private List<SnmpResult> trap = null;
			
			@Override
			public SnmpRequestBuilder community(String community) {
				this.community = community;
				return this;
			}
			@Override
			public SnmpRequestBuilder auth(AuthRemoteSpecification authRemoteSpecification) {
				this.authRemoteSpecification = authRemoteSpecification;
				return this;
			}
			
			private Instance instance = null;
			
			@Override
			public SnmpRequestBuilder build(Address address, Oid oid) {
				this.address = address;
				this.oid = oid;
				return this;
			}
			
			@Override
			public void cancel() {
				// Deprecated
				executor.execute(new Runnable() {
					@Override
					public void run() {
						if (instance != null) {
							instance.cancel();
						}
					}
				});
			}
			
			@Override
			public SnmpRequestBuilder add(Oid oid, String value) {
				if (trap == null) {
					trap = new LinkedList<>();
				}
				trap.add(new SnmpResult(oid, value));
				return this;
			}

			@Override
			public Cancelable call(final SnmpCallType type, final SnmpReceiver r) {
				final Auth auth = (authRemoteSpecification == null) ? null : new Auth(authRemoteSpecification.login, authRemoteSpecification.authPassword, authRemoteSpecification.authDigestAlgorithm, authRemoteSpecification.privPassword, authRemoteSpecification.privEncryptionAlgorithm);;
				final String contextName = (authRemoteSpecification == null) ? null : authRemoteSpecification.contextName;
				final Oid o = oid;
				final Address a = address;
				final String c = community;
				final Iterable<SnmpResult> t = (trap == null) ? null : ImmutableList.copyOf(trap);
				executor.execute(new Runnable() {
					@Override
					public void run() {
						if (instance != null) {
							throw new IllegalStateException();
						}
						
						instance = new Instance(connecter, instanceMapper, o, contextName, a, type, c, t);

						AuthRemoteEnginePendingRequestManager authRemoteEnginePendingRequestManager = null;
						if (auth != null) {
							Auth previousAuth = auths.get(a);
							if (previousAuth != null) {
								if (!previousAuth.equals(auth)) {
									LOGGER.debug("Auth changed ({} -> {}) for {}", previousAuth, auth, a);
								}
							}
							auths.put(a, auth);
							
							EncryptionEngineKey encryptionEngineKey = new EncryptionEngineKey(auth.authDigestAlgorithm, auth.privEncryptionAlgorithm);
							EncryptionEngine encryptionEngine = encryptionEngines.get(encryptionEngineKey);
							if (encryptionEngine == null) {
								encryptionEngine = new EncryptionEngine(auth.authDigestAlgorithm, auth.privEncryptionAlgorithm, AUTH_ENGINES_CACHE_DURATION);
								encryptionEngines.put(encryptionEngineKey, encryptionEngine);
							}

							AuthRemoteEngineKey authRemoteEngineKey = new AuthRemoteEngineKey(a, auth);
							authRemoteEnginePendingRequestManager = authRemoteEngines.get(authRemoteEngineKey);
							if (authRemoteEnginePendingRequestManager == null) {
								authRemoteEnginePendingRequestManager = new AuthRemoteEnginePendingRequestManager(auth, encryptionEngine);
								authRemoteEngines.put(authRemoteEngineKey, authRemoteEnginePendingRequestManager);

								authRemoteEnginePendingRequestManager.discoverIfNecessary(a, connecter);
							}
						}

						instance.receiver = r;
						instance.authRemoteEnginePendingRequestManager = authRemoteEnginePendingRequestManager;
						instance.launch();
					}
				});
				return new Cancelable() {
					@Override
					public void cancel() {
						executor.execute(new Runnable() {
							@Override
							public void run() {
								if (instance != null) {
									instance.cancel();
								}
							}
						});
					}
				};
			}
		};
	}
	@Override
	public void connect(final SnmpConnection callback) {
		connecter.connect(new Connection() {
			@Override
			public void received(final Address address, final ByteBuffer buffer) {
				executor.execute(new Runnable() {
					@Override
					public void run() {
						LOGGER.trace("Received SNMP packet, size = {}", buffer.remaining());
						int instanceId;
						int errorStatus;
						int errorIndex;
						Iterable<SnmpResult> results;

						Auth auth = auths.get(address);

						AuthRemoteEnginePendingRequestManager authRemoteEnginePendingRequestManager;
						if (auth == null) {
							authRemoteEnginePendingRequestManager = null;
						} else {
							AuthRemoteEngineKey authRemoteEngineKey = new AuthRemoteEngineKey(address, auth);
							authRemoteEnginePendingRequestManager = authRemoteEngines.get(authRemoteEngineKey);
						}

						boolean ready;
						if (authRemoteEnginePendingRequestManager != null) {
							ready = authRemoteEnginePendingRequestManager.isReady();
						} else {
							ready = true;
						}
						try {
							SnmpPacketParser parser = new SnmpPacketParser((authRemoteEnginePendingRequestManager == null) ? null : authRemoteEnginePendingRequestManager.engine, buffer);
							instanceId = parser.getRequestId();
							errorStatus = parser.getErrorStatus();
							errorIndex = parser.getErrorIndex();
							results = parser.getResults();
						} catch (Exception e) {
							LOGGER.error("Invalid packet", e);
							return;
						}
						
						if (authRemoteEnginePendingRequestManager != null) {
							if (ready && (errorStatus == BerConstants.ERROR_STATUS_AUTHENTICATION_NOT_SYNCED)) {
								authRemoteEnginePendingRequestManager.reset();
							}

							authRemoteEnginePendingRequestManager.discoverIfNecessary(address, connecter);
							authRemoteEnginePendingRequestManager.sendPendingRequestsIfReady(address, connecter);
						}
						
						instanceMapper.handle(instanceId, errorStatus, errorIndex, results);
					}
				});
			}
			
			@Override
			public void failed(final IOException ioe) {
				executor.execute(new Runnable() {
					@Override
					public void run() {
						instanceMapper.fail(ioe);
					}
				});
				
				if (callback != null) {
					callback.failed(ioe);
				}
			}
			
			@Override
			public void connected(Address address) {
				if (callback != null) {
					callback.connected(address);
				}
			}
			
			@Override
			public void closed() {
				executor.execute(new Runnable() {
					@Override
					public void run() {
						instanceMapper.fail(new IOException("Closed"));
					}
				});
				
				if (callback != null) {
					callback.closed();
				}
			}
		});
	}
	
	@Override
	public void close() {
		executor.execute(new Runnable() {
			@Override
			public void run() {
				instanceMapper.close();
			}
		});
		
		connecter.close();
	}
	
	private static final class AuthRemoteEnginePendingRequestManager {
		public static final class PendingRequest {
			public final SnmpCallType request;
			public final int instanceId;
			public final Oid oid;
			public final String contextName;
//			public final Iterable<SnmpResult> trap;
			public final SendCallback sendCallback;

			public PendingRequest(SnmpCallType request, int instanceId, Oid oid, String contextName, /*Iterable<SnmpResult> trap, */SendCallback sendCallback) {
				this.request = request;
				this.instanceId = instanceId;
				this.oid = oid;
				this.contextName = contextName;
//				this.trap = trap;
				this.sendCallback = sendCallback;
			}
		}
		
		public final AuthRemoteEngine engine;
		public final List<PendingRequest> pendingRequests = new LinkedList<>();
		
		public AuthRemoteEnginePendingRequestManager(Auth auth, EncryptionEngine encryptionEngine) {
			engine = new AuthRemoteEngine(auth, encryptionEngine);
		}
		
		public boolean isReady() {
			return engine.isValid();
		}
		
		public void reset() {
			engine.reset();
		}
		
		public void discoverIfNecessary(Address address, Connecter connector) {
			if (!engine.isValid()) {
				Version3PacketBuilder builder = Version3PacketBuilder.get(engine, null, RequestIdProvider.IGNORE_ID, null);
				ByteBuffer b = builder.getBuffer();
				LOGGER.trace("Writing discover GET v3: #{}, packet size = {}", RequestIdProvider.IGNORE_ID, b.remaining());
				connector.send(address, b, new SendCallback() {
					@Override
					public void sent() {
					}
					@Override
					public void failed(IOException ioe) {
						IOException e = new IOException("Failed to send discover packet", ioe);
						for (PendingRequest r : pendingRequests) {
							r.sendCallback.failed(e);
						}
						pendingRequests.clear();
					}
				});
			}
		}
		
		public void registerPendingRequest(PendingRequest r) {
			pendingRequests.add(r);
		}
		public void clearPendingRequests() {
			pendingRequests.clear();
		}
		
		public void sendPendingRequestsIfReady(Address address, Connecter connector) {
			if (!engine.isValid()) {
				return;
			}
			
			for (PendingRequest r : pendingRequests) {
				switch (r.request) {
					case GET: {
						Version3PacketBuilder builder = Version3PacketBuilder.get(engine, r.contextName, r.instanceId, r.oid);
						ByteBuffer b = builder.getBuffer();
						LOGGER.trace("Writing GET v3: {} #{}, packet size = {}", r.oid, r.instanceId, b.remaining());
						connector.send(address, b, r.sendCallback);
						break;
					}
					case GETNEXT: {
						Version3PacketBuilder builder = Version3PacketBuilder.getNext(engine, r.contextName, r.instanceId, r.oid);
						ByteBuffer b = builder.getBuffer();
						LOGGER.trace("Writing GETNEXT v3: {} #{}, packet size = {}", r.oid, r.instanceId, b.remaining());
						connector.send(address, b, r.sendCallback);
						break;
					}
					case GETBULK: {
						Version3PacketBuilder builder = Version3PacketBuilder.getBulk(engine, r.contextName, r.instanceId, r.oid, BULK_SIZE);
						ByteBuffer b = builder.getBuffer();
						LOGGER.trace("Writing GETBULK v3: {} #{}, packet size = {}", r.oid, r.instanceId, b.remaining());
						connector.send(address, b, r.sendCallback);
						break;
					}
					case TRAP: {
						LOGGER.error("No TRAP possible in v3: {} #{}", r.oid, r.instanceId);
/*
						Version3PacketBuilder builder = Version3PacketBuilder.trap(engine, r.instanceId, r.oid, r.trap);
						ByteBuffer b = builder.getBuffer();
						LOGGER.trace("Writing TRAP v3: {} #{}, packet size = {}", r.oid, r.instanceId, b.remaining());
						connector.send(address, b, r.sendCallback);
*/
						break;
					}
					default:
						break;
				}
			}
			pendingRequests.clear();
		}
	}
	
	private static final class RequestIdProvider {

		private static final Random RANDOM = new SecureRandom();

		private static final int MIN_ID = 1_000;
		private static final int MAX_ID = 2_043_088_696; // Let's do as snmpwalk is doing
		public static final int IGNORE_ID = MAX_ID;
		
		private static int NEXT = MAX_ID;
		
		private static final Object LOCK = new Object();

		public RequestIdProvider() {
		}
		
		public int get() {
			synchronized (LOCK) {
				if (NEXT == MAX_ID) {
					NEXT = MIN_ID + RANDOM.nextInt(MAX_ID - MIN_ID);
				}
				int k = NEXT;
				NEXT++;
				return k;
			}
		}
	}
	
	private static final class InstanceMapper {
		private final RequestIdProvider requestIdProvider;
		private final Map<Integer, Instance> instances = new HashMap<>();
		
		public InstanceMapper(RequestIdProvider requestIdProvider) {
			this.requestIdProvider = requestIdProvider;
		}
		
		public void map(Instance instance) {
			instances.remove(instance.instanceId);
			
			int instanceId = requestIdProvider.get();

			if (instances.containsKey(instanceId)) {
				LOGGER.warn("The maximum number of simultaneous request has been reached");
				return;
			}
			
			instances.put(instanceId, instance);
			
			LOGGER.trace("New instance ID = {}", instanceId);
			instance.instanceId = instanceId;
		}
		
		public void unmap(Instance instance) {
			instances.remove(instance.instanceId);
			instance.instanceId = RequestIdProvider.IGNORE_ID;
		}
		
		public void close() {
			for (Instance i : instances.values()) {
				i.close();
			}
			instances.clear();
		}

		public void fail(IOException ioe) {
			for (Instance i : instances.values()) {
				i.fail(ioe);
			}
			instances.clear();
		}

		public void handle(int instanceId, int errorStatus, int errorIndex, Iterable<SnmpResult> results) {
			if (instanceId == Integer.MAX_VALUE) {
				LOGGER.trace("Calling all instances (request ID = {})", Integer.MAX_VALUE);
				List<Instance> l = new LinkedList<>(instances.values());
				instances.clear();
				for (Instance i : l) {
					i.handle(errorStatus, errorIndex, results);
				}
				return;
			}
			
			Instance i = instances.remove(instanceId);
			if (i == null) {
				return;
			}
			i.handle(errorStatus, errorIndex, results);
		}
	}
	
	private static final class Instance {
		private final Connecter connector;
		private final InstanceMapper instanceMapper;
		
		private SnmpReceiver receiver;
		
		private final Oid requestOid;
		private final String requestContextName;
		public int instanceId = RequestIdProvider.IGNORE_ID;

		private final Address address;
		private final String community;
		private final SnmpCallType snmpCallType;
		private AuthRemoteEnginePendingRequestManager authRemoteEnginePendingRequestManager = null;
		
		private final Iterable<SnmpResult> trap;

		public Instance(Connecter connector, InstanceMapper instanceMapper, Oid requestOid, String requestContextName, Address address, SnmpCallType snmpCallType, String community, Iterable<SnmpResult> trap) {
			this.connector = connector;
			this.instanceMapper = instanceMapper;
			
			this.requestOid = requestOid;
			this.requestContextName = requestContextName;
			
			this.address = address;
			this.snmpCallType = snmpCallType;
			this.community = community;
			
			this.trap = trap;
		}
		
		public void launch() {
			if (receiver != null) {
				instanceMapper.map(this);
			}
			write();
		}
		
		public void close() {
			receiver = null;
		}
		
		public void cancel() {
			if (authRemoteEnginePendingRequestManager != null) {
				authRemoteEnginePendingRequestManager.clearPendingRequests();
			}
			instanceMapper.unmap(this);
			receiver = null;
		}

		private void write() {
			SendCallback sendCallback = new SendCallback() {
				@Override
				public void sent() {
				}
				@Override
				public void failed(IOException ioe) {
					fail(ioe);
				}
			};
			
			if (authRemoteEnginePendingRequestManager == null) {
				switch (snmpCallType) { 
					case GET: {
						Version2cPacketBuilder builder = Version2cPacketBuilder.get(community, instanceId, requestOid);
						ByteBuffer b = builder.getBuffer();
						LOGGER.trace("Writing GET: {} #{} ({}), packet size = {}", requestOid, instanceId, community, b.remaining());
						connector.send(address, b, sendCallback);
						break;
					}
					case GETNEXT: {
						Version2cPacketBuilder builder = Version2cPacketBuilder.getNext(community, instanceId, requestOid);
						ByteBuffer b = builder.getBuffer();
						LOGGER.trace("Writing GETNEXT: {} #{} ({}), packet size = {}", requestOid, instanceId, community, b.remaining());
						connector.send(address, b, sendCallback);
						break;
					}
					case GETBULK: {
						Version2cPacketBuilder builder = Version2cPacketBuilder.getBulk(community, instanceId, requestOid, BULK_SIZE);
						ByteBuffer b = builder.getBuffer();
						LOGGER.trace("Writing GETBULK: {} #{} ({}), packet size = {}", requestOid, instanceId, community, b.remaining());
						connector.send(address, b, sendCallback);
						break;
					}
					case TRAP: {
						Version2cPacketBuilder builder = Version2cPacketBuilder.trap(community, instanceId, requestOid, trap);
						ByteBuffer b = builder.getBuffer();
						LOGGER.trace("Writing TRAP: {} #{} ({}), packet size = {}", requestOid, instanceId, community, b.remaining());
						connector.send(address, b, sendCallback);
						break;
					}
					default:
						break;
				}
			} else {
				authRemoteEnginePendingRequestManager.registerPendingRequest(new AuthRemoteEnginePendingRequestManager.PendingRequest(snmpCallType, instanceId, requestOid, requestContextName, /*trap, */sendCallback));
				authRemoteEnginePendingRequestManager.sendPendingRequestsIfReady(address, connector);
			}
		}
	
		public void fail(IOException e) {
			if (receiver != null) {
				receiver.failed(e);
			}
			receiver = null;
		}
		
		private void handle(int errorStatus, int errorIndex, Iterable<SnmpResult> results) {
			if (requestOid == null) {
				return;
			}

			if (errorStatus == BerConstants.ERROR_STATUS_AUTHENTICATION_NOT_SYNCED) {
				fail(new IOException("Authentication engine not synced"));
				return;
			}

			if (errorStatus == BerConstants.ERROR_STATUS_AUTHENTICATION_FAILED) {
				fail(new IOException("Authentication failed"));
				return;
			}
			
			if (errorStatus == BerConstants.ERROR_STATUS_TIMEOUT) {
				fail(new IOException("Timeout"));
				return;
			}

			if (errorStatus != 0) {
				LOGGER.trace("Received error: {}/{}", errorStatus, errorIndex);
			}

			for (SnmpResult r : results) {
				if (r.value == null) {
					continue;
				}
				LOGGER.trace("Addind to results: {}", r);
				if (receiver != null) {
					receiver.received(r);
				}
			}
			if (receiver != null) {
				receiver.finished();
			}
			receiver = null;
		}
	}
}
