package app.hypi.mekadb;

import app.hypi.mekadb.client.*;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.stub.StreamObserver;
import org.joda.time.DateTime;
import org.joda.time.Duration;

import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.ReentrantLock;

/**
 * Client for a MekaDB server.
 *
 * <p>Authentication is a unary gRPC call ({@link #login}); SQL runs over one long-lived
 * bidirectional {@code sqlWithJsonResponse} stream. Every {@link #query} is tagged with a
 * client-generated request id, and the response carrying that id completes the returned future.
 *
 * <p>When the stream errors or the server completes it, all outstanding queries are failed and a
 * new channel and stream are opened after a one-second pause.
 *
 * <p>Thread-safety: {@code lock} guards the connection state and serialises writes to the request
 * stream, because a gRPC {@link StreamObserver} must not be called concurrently.
 */
public class MekaDBClient {
  private static final String DEFAULT_HOST = "mekadb.hypi.app";
  private static final int DEFAULT_PORT = 443;
  /** Pause before opening a replacement connection after the stream fails. */
  private static final long RECONNECT_DELAY_MILLIS = 1000;
  /** How long {@link #shutdown()} waits for a graceful close before forcing it. */
  private static final long SHUTDOWN_TIMEOUT_SECONDS = 5;
  /** Jackson target type for the JSON rows carried by a successful response. */
  private static final TypeReference<List<Map<String, Object>>> ROWS_TYPE =
      new TypeReference<List<Map<String, Object>>>() {
      };

  private final String host;
  private final int port;
  private final boolean withTls;
  private final AtomicLong requestIdCounter = new AtomicLong(0);
  /** Futures for queries that were sent but not yet answered, keyed by request id. */
  private final Map<Long, CompletableFuture<List<Map<String, Object>>>> pendingRequests = new ConcurrentHashMap<>();
  private final ObjectMapper objectMapper = new ObjectMapper();
  private final ReentrantLock lock = new ReentrantLock();

  // Connection state. Written only while holding `lock`; volatile so code that reads it without
  // the lock (login, the disconnect heuristic) sees the current connection.
  private volatile ManagedChannel channel;
  private volatile MekaDBClientGrpc.MekaDBClientBlockingStub blockingStub;
  private volatile MekaDBClientGrpc.MekaDBClientStub asyncStub;
  private volatile StreamObserver<SqlRequest> requestObserver;
  /** Bumped whenever a stream is opened or retired; callbacks from older streams are ignored. */
  private volatile long streamGeneration;
  /** Set by {@link #shutdown()}; once true the client never reconnects and rejects new queries. */
  private volatile boolean closed;

  /** When the most recent response arrived on the stream. */
  private volatile DateTime lastRes;
  /** Read by the disconnect heuristic; NOTE(review): never assigned anywhere in this class. */
  private volatile DateTime lastReq;

  /** Connects to the hosted MekaDB endpoint ({@code mekadb.hypi.app:443}) over TLS. */
  public MekaDBClient() {
    this(DEFAULT_HOST, DEFAULT_PORT, true);
  }

  /**
   * Opens a channel to {@code host:port} and starts the SQL stream.
   *
   * @param host    server host name
   * @param port    server port
   * @param withTls {@code true} for transport security, {@code false} for plaintext
   */
  public MekaDBClient(String host, int port, boolean withTls) {
    this.host = host;
    this.port = port;
    this.withTls = withTls;
    // Stream callbacks may fire on gRPC threads right away; holding the lock makes them wait
    // until the connection state is fully initialised.
    lock.lock();
    try {
      connect();
      initStream();
    } finally {
      lock.unlock();
    }
  }

  /** Builds a new channel plus blocking and async stubs. Caller must hold {@code lock}. */
  private void connect() {
    ManagedChannelBuilder<?> builder = ManagedChannelBuilder.forAddress(host, port);
    if (withTls) {
      builder.useTransportSecurity();
    } else {
      builder.usePlaintext();
    }
    ManagedChannel newChannel = builder.build();
    channel = newChannel;
    blockingStub = MekaDBClientGrpc.newBlockingStub(newChannel);
    asyncStub = MekaDBClientGrpc.newStub(newChannel);
  }

  /**
   * Opens a new SQL stream on the current channel. Caller must hold {@code lock}.
   *
   * <p>The observer remembers the generation it was created under, so an error or completion
   * from a stream that has since been replaced cannot tear down its successor.
   */
  private void initStream() {
    final long generation = ++streamGeneration;
    requestObserver = asyncStub.sqlWithJsonResponse(new StreamObserver<SqlResponse>() {
      @Override
      public void onNext(SqlResponse sqlResponse) {
        lastRes = DateTime.now();
        handleResponse(sqlResponse);
      }

      @Override
      public void onError(Throwable t) {
        reconnect(t, generation);
      }

      @Override
      public void onCompleted() {
        reconnect(new ConnectionFailed("Server closed the SQL stream"), generation);
      }
    });
  }

  /**
   * Stops the client: half-closes the SQL stream, shuts the channel down gracefully, forces it
   * closed if it has not terminated within five seconds, and fails any queries still waiting.
   * After this call the client no longer reconnects and new queries fail immediately.
   *
   * @throws InterruptedException if interrupted while waiting; the channel is then forced closed
   */
  public void shutdown() throws InterruptedException {
    ManagedChannel current;
    lock.lock();
    try {
      current = channel;
      if (!closed) {
        closed = true;
        try {
          requestObserver.onCompleted();
        } catch (RuntimeException ignored) {
          // The call is already cancelled or closed; there is nothing left to half-close.
        }
      }
    } finally {
      lock.unlock();
    }

    current.shutdown();
    try {
      if (!current.awaitTermination(SHUTDOWN_TIMEOUT_SECONDS, TimeUnit.SECONDS)) {
        current.shutdownNow();
      }
    } catch (InterruptedException e) {
      current.shutdownNow();
      throw e;
    } finally {
      // Responses can still arrive during the graceful wait; whatever is left will never complete.
      failPendingRequests(new ConnectionFailed("MekaDBClient has been shut down"));
    }
  }

  /**
   * Authenticates without selecting a schema.
   *
   * @see #login(String, String, String, String)
   */
  public AuthCtx login(String username, String password, String database) {
    return login(username, password, database, null);
  }

  /**
   * Authenticates with a blocking gRPC call.
   *
   * @param schema optional schema; ignored when {@code null} or empty
   * @return the auth context to pass to {@link #query}
   */
  public AuthCtx login(String username, String password, String database, String schema) {
    AuthReq.Builder authReqBuilder = AuthReq.newBuilder()
        .setUsername(username)
        .setPassword(password)
        .setDatabase(database);
    if (schema != null && !schema.isEmpty()) {
      authReqBuilder.setSchema(schema);
    }
    return blockingStub.authenticate(authReqBuilder.build());
  }

  /** Runs {@code sql} without named parameters. */
  public CompletableFuture<List<Map<String, Object>>> query(AuthCtx creds, String sql) {
    return query(creds, sql, null);
  }

  /**
   * Sends {@code sql} on the SQL stream and returns a future for its rows.
   *
   * <p>The future completes with the parsed JSON rows. It completes exceptionally when:
   * <ul>
   *   <li>the server reports an error;</li>
   *   <li>the response cannot be parsed or carries no payload;</li>
   *   <li>the stream fails (a {@link ConnectionFailed}, or the gRPC error itself);</li>
   *   <li>the request cannot be written;</li>
   *   <li>the client has been shut down.</li>
   * </ul>
   *
   * @param params named placeholder values, may be {@code null}. Supported value types are
   *               Integer/Short/Byte (i32), Long, Boolean, Double, Float, {@link java.util.Date}
   *               (and subclasses, as epoch millis) and String. Any other value, including
   *               {@code null}, is sent as an empty placeholder value.
   */
  public CompletableFuture<List<Map<String, Object>>> query(AuthCtx creds, String sql, Map<String, Object> params) {
    if (looksDisconnected()) {
      reconnect(new ConnectionFailed("No response from server; reconnecting"), streamGeneration);
    }

    SqlRequest.Builder requestBuilder = SqlRequest.newBuilder()
        .setAuth(creds)
        .setQuery(sql)
        .setNamed(toPlaceholders(params));

    CompletableFuture<List<Map<String, Object>>> future = new CompletableFuture<>();
    // Registration and the write happen under the lock so they cannot interleave with a
    // reconnect (which would otherwise fail the future but still send on the new stream) or with
    // another thread's write: StreamObserver is not thread-safe.
    lock.lock();
    try {
      if (closed) {
        future.completeExceptionally(new IllegalStateException("MekaDBClient has been shut down"));
        return future;
      }
      long requestId = requestIdCounter.incrementAndGet();
      pendingRequests.put(requestId, future);
      try {
        requestObserver.onNext(requestBuilder.setRequestId(requestId).build());
      } catch (RuntimeException e) {
        // Nothing was sent, so no response will ever arrive for this id.
        pendingRequests.remove(requestId);
        future.completeExceptionally(e);
      }
    } finally {
      lock.unlock();
    }
    return future;
  }

  /** Converts caller parameters into the protobuf named-placeholder message. */
  private static NamedQueryPlaceHolder toPlaceholders(Map<String, Object> params) {
    NamedQueryPlaceHolder.Builder paramsBuilder = NamedQueryPlaceHolder.newBuilder();
    if (params != null) {
      params.forEach((name, value) -> paramsBuilder.addValues(
          PlaceholderPair.newBuilder().setName(name).setValue(toPlaceholderValue(value))));
    }
    return paramsBuilder.build();
  }

  /** Maps one Java value onto the matching PlaceholderValue field; unsupported values stay empty. */
  private static PlaceholderValue.Builder toPlaceholderValue(Object value) {
    PlaceholderValue.Builder valueBuilder = PlaceholderValue.newBuilder();
    if (value instanceof Integer || value instanceof Short || value instanceof Byte) {
      valueBuilder.setI32T(((Number) value).intValue());
    } else if (value instanceof Long) {
      valueBuilder.setI64T((Long) value);
    } else if (value instanceof Boolean) {
      valueBuilder.setBoolT((Boolean) value);
    } else if (value instanceof Double) {
      valueBuilder.setDoubleT((Double) value);
    } else if (value instanceof Float) {
      valueBuilder.setFloatT((Float) value);
    } else if (value instanceof java.util.Date) {
      // Also covers java.sql.Date and java.sql.Timestamp, which extend java.util.Date.
      valueBuilder.setTimestampMillis(((java.util.Date) value).getTime());
    } else if (value instanceof String) {
      valueBuilder.setStrT((String) value);
    }
    return valueBuilder;
  }

  /**
   * Rough heuristic for a silently broken stream: requests are outstanding, a request went out in
   * the last 30 seconds, yet nothing has been received for more than 40 seconds.
   *
   * <p>NOTE(review): {@code lastReq} is never assigned, so this currently always returns false.
   * Recording every send time in {@link #query} would not fix that safely: after an idle period,
   * two back-to-back queries would satisfy the condition and the first one would be failed
   * spuriously. Tracking the send time of the oldest outstanding request would be a sounder signal.
   */
  private boolean looksDisconnected() {
    DateTime sent = lastReq;
    DateTime received = lastRes;
    if (pendingRequests.isEmpty() || sent == null || received == null) {
      return false;
    }
    DateTime now = DateTime.now();
    return new Duration(sent, now).getStandardSeconds() <= 30
        && new Duration(received, now).getStandardSeconds() > 40;
  }

  /**
   * Replaces the connection after the stream identified by {@code generation} failed.
   *
   * <p>The request is ignored if the client is shut down or a newer stream already exists. The
   * old channel is shut down so it is not leaked. Queries that were waiting on the old stream are
   * failed with {@code cause} after the lock is released, so their callbacks (for example a retry
   * that calls {@link #query}) see the new stream.
   */
  private void reconnect(Throwable cause, long generation) {
    List<CompletableFuture<List<Map<String, Object>>>> orphaned = null;
    lock.lock();
    try {
      if (closed || generation != streamGeneration) {
        return;
      }
      // Retire the failing stream first so any late callback from it is recognised as stale.
      streamGeneration++;
      orphaned = drainPendingRequests();
      ManagedChannel previous = channel;
      if (previous != null) {
        previous.shutdownNow();
      }
      try {
        Thread.sleep(RECONNECT_DELAY_MILLIS);
      } catch (InterruptedException e) {
        // Skip the rest of the back-off but still reconnect so the client stays usable.
        Thread.currentThread().interrupt();
      }
      connect();
      initStream();
    } finally {
      lock.unlock();
      if (orphaned != null) {
        failAll(orphaned, cause);
      }
    }
  }

  /** Removes every pending future from the map and fails it with {@code e}. */
  private void failPendingRequests(Throwable e) {
    failAll(drainPendingRequests(), e);
  }

  /**
   * Atomically removes and returns all pending futures. Removing one key at a time means a
   * future registered concurrently is either drained here or left in the map, never dropped.
   */
  private List<CompletableFuture<List<Map<String, Object>>>> drainPendingRequests() {
    List<CompletableFuture<List<Map<String, Object>>>> drained = new ArrayList<>();
    for (Long requestId : pendingRequests.keySet()) {
      CompletableFuture<List<Map<String, Object>>> future = pendingRequests.remove(requestId);
      if (future != null) {
        drained.add(future);
      }
    }
    return drained;
  }

  private static void failAll(List<CompletableFuture<List<Map<String, Object>>>> futures, Throwable cause) {
    for (CompletableFuture<List<Map<String, Object>>> future : futures) {
      future.completeExceptionally(cause);
    }
  }

  /** Completes the future registered under the response's request id, if one is still waiting. */
  private void handleResponse(SqlResponse sqlResponse) {
    CompletableFuture<List<Map<String, Object>>> future = pendingRequests.remove(sqlResponse.getRequestId());
    if (future == null) {
      return; // Unknown id, or the request was already failed by a reconnect or shutdown.
    }

    switch (sqlResponse.getPayloadCase()) {
      case RESPONSE:
        try {
          future.complete(objectMapper.readValue(sqlResponse.getResponse(), ROWS_TYPE));
        } catch (Exception e) {
          future.completeExceptionally(e);
        }
        break;
      case ERROR:
        future.completeExceptionally(new Exception(sqlResponse.getError().getMessage()));
        break;
      default:
        // The future has already been removed from the map; fail it rather than leave the caller
        // waiting forever.
        future.completeExceptionally(new IllegalStateException(
            "Response to request " + sqlResponse.getRequestId() + " carried no payload"));
        break;
    }
  }

  /** Signals that the SQL stream was lost, closed by the server, or shut down by the client. */
  public static class ConnectionFailed extends RuntimeException {
    public ConnectionFailed() {
      super();
    }

    public ConnectionFailed(String message) {
      super(message);
    }
  }

  public static void main(String[] args) throws ExecutionException, InterruptedException {
    MekaDBClient client = new MekaDBClient();
    try {
      AuthCtx auth = client.login("<username>", "<password>", "<database>");
      List<Map<String, Object>> createTableRes = client.query(auth, "CREATE TABLE IF NOT EXISTS user(username VARCHAR, pass VARCHAR, PRIMARY KEY (username))").get();
      System.out.println(createTableRes);
      List<Map<String, Object>> insertRes = client.query(auth, "INSERT INTO user(username,pass) VALUES('courtney','pass1'),('damion','pass2')").get();
      System.out.println(insertRes);
      Map<String, Object> params = new LinkedHashMap<>();
      params.put("pass", "pass1");
      List<Map<String, Object>> rows = client.query(auth, "SELECT * FROM user where pass = :pass", params).get();
      System.out.println(rows);
    } finally {
      client.shutdown();
    }
  }
}
