package sila_java.library.manager.server_management;

import com.google.common.net.HostAndPort;
import io.grpc.ManagedChannel;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import sila2.org.silastandard.core.silaservice.v1.SiLAServiceGrpc;
import sila_java.library.core.asynchronous.TaskScheduler;
import sila_java.library.manager.SiLAManager;
import sila_java.library.manager.models.Server;

import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.util.HashSet;
import java.util.Set;
import java.util.UUID;

import static sila_java.library.manager.server_management.ServerLoading.getServerId;

/**
 * Pings all SiLA Servers in cache and monitors the online offline status if valid
 */
@Slf4j
public class ServerHeartbeat {
    private static final int CONNECT_TIMEOUT = 2000; // [ms] When a server is assumed to be down
    private final SiLAManager siLAManager;
    private final TaskScheduler pingScheduler;
    private final DiscoveryListener discoveryListener;

    /**
     * Constructor
     * @param siLAManager SiLA Manager managing SiLA Servers
     * @param discoveryListener DiscoveryListener to validate cache
     * @param period time in milliseconds between successive heartbeats.
     */
    public ServerHeartbeat(
            @NonNull final SiLAManager siLAManager,
            @NonNull final DiscoveryListener discoveryListener,
            final int period
    ) {
        this.siLAManager = siLAManager;
        this.discoveryListener = discoveryListener;
        this.pingScheduler = new TaskScheduler(this::checkAll, period, "Ping Scheduler");

        Runtime.getRuntime().addShutdownHook(new Thread(this::stop));
    }

    public void start() {
        this.pingScheduler.start();
    }

    public void stop() {
        this.pingScheduler.stop();
    }

    /**
     * Pinging SiLA Servers & Update Manager
     */
    private void checkAll() {
        final Set<HostAndPort> onlineSockets = new HashSet<>();

        siLAManager.getSiLAServers().forEach((key, server) -> {
            if (!server.getStatus().equals(Server.Status.INVALID)) {
                Server.Status status = Server.Status.OFFLINE;

                // Server is only online if socket not already taken,
                // connection is up and UUID is the same.
                final HostAndPort hostAndPort = HostAndPort.fromParts(server.getHost(), server.getPort());
                // Host and Port is unique from client perspective
                final boolean socketUp = checkConnection(server.getHost(), server.getPort());

                if (!onlineSockets.contains(hostAndPort) && socketUp) {
                    try {
                        final ManagedChannel managedChannel = this.siLAManager
                                .getSilaConnections()
                                .get(key)
                                .getManagedChannel();
                        final UUID uuid = getServerId(SiLAServiceGrpc.newBlockingStub(managedChannel));

                        if (server.getConfiguration().getUuid().equals(uuid)) {
                            status = Server.Status.ONLINE;
                            onlineSockets.add(hostAndPort);
                        }
                    } catch (Exception e) {
                        log.debug("Failing UUID Retrieval {}", e.getMessage());
                    }
                }
                this.siLAManager.setServerStatus(server.getConfiguration().getUuid(), status);
            }
        });
    }

    /**
     * Utility to check Connection of a Socket
     * @param host Host of Socket
     * @param port Port of Socket
     * @return If Socket is up, returns true
     */
    private boolean checkConnection(@NonNull final String host, final int port) {
        final InetSocketAddress endPoint = new InetSocketAddress(host, port);

        if (endPoint.isUnresolved()) {
            log.error("Failure " + endPoint);
            return (false);
        }
        try (final Socket socket = new Socket()) {
            socket.connect(endPoint, CONNECT_TIMEOUT);
            log.debug("Success reaching socket {}", endPoint);
            return (true);
        } catch (final IOException e) {
            log.debug("Failure reaching socket {} message: {} - {}", endPoint , e.getClass().getSimpleName(), e.getMessage());
        }
        return (false);
    }
}