/*
 * Decompiled with CFR 0.152.
 */
package io.moderne.ai;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardCopyOption;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.function.Function;
import kong.unirest.HttpRequestWithBody;
import kong.unirest.HttpResponse;
import kong.unirest.Unirest;
import kong.unirest.UnirestException;
import org.openrewrite.internal.lang.Nullable;

/**
 * Client for a locally running Python model daemon that decides whether two strings are related.
 *
 * <p>The daemon is expected at {@code http://127.0.0.1:7869}. If nothing answers there,
 * {@link #getInstance()} copies the bundled {@code /get_related.py} resource into
 * {@code ~/.moderne/models} and launches it with {@code /usr/bin/python3}. Predictions are requested
 * by POSTing a {@code data} array to {@code /run/predict}; the request/response shapes suggest a
 * Gradio app (NOTE(review): confirm against get_related.py).
 *
 * <p>Verdicts are cached per (t1, t2, threshold) triple, keeping at most 1000 entries.
 */
public class RelatedModelClient {
    private static final String BASE_URL = "http://127.0.0.1:7869";
    private static final String PREDICT_URL = BASE_URL + "/run/predict";
    /** How many one-second polls to make while waiting for a freshly launched daemon. */
    private static final int STARTUP_TIMEOUT_SECONDS = 60;
    private static final int MAX_CACHED_EMBEDDINGS = 1000;
    /** Pseudo status returned by {@link #checkForUpRequest()} when the request itself fails. */
    private static final int UNREACHABLE_STATUS = 523;

    /**
     * Runs the tasks that copy the daemon's console output. The threads are daemons so that a
     * long-running model process does not keep this JVM alive after the application finishes.
     */
    private static final ExecutorService EXECUTOR_SERVICE = Executors.newFixedThreadPool(3, task -> {
        Thread thread = new Thread(task, "related-model-output");
        thread.setDaemon(true);
        return thread;
    });

    private static final Path MODELS_DIR = Paths.get(System.getProperty("user.home"), ".moderne", "models");

    static {
        try {
            // isDirectory follows symlinks, so a symlinked models directory is accepted as-is;
            // createDirectories does not fail if another process creates the directory first.
            if (!Files.isDirectory(MODELS_DIR)) {
                Files.createDirectories(MODELS_DIR);
            }
        } catch (IOException e) {
            throw new IllegalStateException("Unable to create models directory at " + MODELS_DIR, e);
        }
    }

    /** Guarded by the class lock taken in {@link #getInstance()}. */
    @Nullable
    private static RelatedModelClient INSTANCE;

    /**
     * Verdict cache. LinkedHashMap's default insertion order means the oldest inserted entry is
     * evicted first once the size limit is exceeded.
     */
    private final Map<Embedding, Boolean> embeddingCache = Collections.synchronizedMap(new LinkedHashMap<Embedding, Boolean>() {
        @Override
        protected boolean removeEldestEntry(Map.Entry<Embedding, Boolean> eldest) {
            return this.size() > MAX_CACHED_EMBEDDINGS;
        }
    });

    /**
     * Returns the shared client, launching the model daemon first if it is not already answering.
     *
     * <p>The instance is cached only once the daemon is reachable, so a failed launch is retried on
     * the next call instead of handing out a client with no daemon behind it.
     *
     * @return the shared client
     * @throws IllegalStateException if the daemon could not be brought up
     * @throws UncheckedIOException if the launcher script could not be installed or executed
     */
    public static synchronized RelatedModelClient getInstance() {
        RelatedModelClient instance = INSTANCE;
        if (instance == null) {
            instance = new RelatedModelClient();
            if (instance.checkForUpRequest() != 200) {
                instance.start();
            }
            INSTANCE = instance;
        }
        return instance;
    }

    /**
     * Installs {@code get_related.py} into the models directory, launches it and waits for the
     * daemon to answer HTTP requests.
     *
     * <p>NOTE(review): on timeout the process is deliberately left running. A slow first start, for
     * example one that is still loading a model, can then still become reachable for a later
     * {@link #getInstance()} call.
     *
     * @throws IllegalStateException if the daemon does not come up; the message includes its output
     * @throws UncheckedIOException if the script cannot be copied or the process cannot be started
     */
    private void start() {
        Path pyLauncher = MODELS_DIR.resolve("get_related.py");
        try {
            installLauncher(pyLauncher);
            StringWriter sw = new StringWriter();
            PrintWriter procOut = new PrintWriter(sw);
            // An argument list (not "sh -c <string>") passes a home directory containing spaces or
            // shell metacharacters through verbatim. Merging stderr into stdout lets one reader drain
            // both. Reading them one after the other would leave stderr unread for as long as the
            // daemon keeps stdout open, and the daemon could then block on a full pipe.
            Process proc = new ProcessBuilder("/usr/bin/python3", pyLauncher.toString())
                    .redirectErrorStream(true)
                    .start();
            EXECUTOR_SERVICE.submit(() -> copyOutput(proc, procOut));
            if (!this.checkForUp(proc)) {
                throw new IllegalStateException("Unable to start model daemon. Output of process is:\n" + sw);
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    /** Copies the bundled {@code /get_related.py} resource to {@code target}, replacing any older copy. */
    private static void installLauncher(Path target) throws IOException {
        try (InputStream script = RelatedModelClient.class.getResourceAsStream("/get_related.py")) {
            Files.copy(Objects.requireNonNull(script, "Resource /get_related.py not found on the classpath"),
                    target, StandardCopyOption.REPLACE_EXISTING);
        }
    }

    /**
     * Appends every line read from the process's standard output to {@code sink} until the stream
     * ends. {@link #start()} merges stderr into this stream.
     */
    private static void copyOutput(Process proc, PrintWriter sink) {
        try (BufferedReader reader = new BufferedReader(
                new InputStreamReader(proc.getInputStream(), StandardCharsets.UTF_8))) {
            reader.lines().forEach(sink::println);
        } catch (IOException | UncheckedIOException ignored) {
            // The output is kept only for error messages. If reading fails (e.g. the pipe is closed),
            // capturing simply stops.
        }
    }

    /**
     * Polls the daemon once per second for up to {@link #STARTUP_TIMEOUT_SECONDS} attempts.
     *
     * @param proc the launched daemon process
     * @return {@code true} as soon as the daemon answers HTTP 200; {@code false} if the process
     *     exits with a non-zero status first or the attempts run out
     * @throws IllegalStateException if the waiting thread is interrupted (the interrupt flag is restored)
     */
    private boolean checkForUp(Process proc) {
        for (int attempt = 0; attempt < STARTUP_TIMEOUT_SECONDS; attempt++) {
            // Probe HTTP first: a daemon that already answers (e.g. one started by an earlier
            // attempt) counts as up even if this particular process has exited.
            if (this.checkForUpRequest() == 200) {
                return true;
            }
            if (!proc.isAlive() && proc.exitValue() != 0) {
                return false;
            }
            try {
                Thread.sleep(1000L);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new IllegalStateException("Interrupted while waiting for the model daemon to start", e);
            }
        }
        return false;
    }

    /**
     * Sends a HEAD request to the daemon's base URL.
     *
     * @return the HTTP status, or 523 if Unirest could not complete the request
     */
    private int checkForUpRequest() {
        try {
            HttpResponse<String> response = Unirest.head(BASE_URL).asString();
            return response.getStatus();
        } catch (UnirestException unreachable) {
            return UNREACHABLE_STATUS;
        }
    }

    /**
     * Decides whether {@code t1} and {@code t2} are related at {@code threshold}. If the same triple
     * was asked before and is still cached, the cached verdict is reused.
     *
     * @return the verdict and the model-call timing; the timing list is empty on a cache hit
     */
    public Relatedness getRelatedness(String t1, String t2, double threshold) {
        List<Duration> timings = new ArrayList<>(1);
        // Collections.synchronizedMap runs computeIfAbsent while holding its lock, so concurrent
        // cache misses reach the model one at a time.
        boolean related = this.embeddingCache.computeIfAbsent(new Embedding(t1, t2, threshold), this.timeEmbedding(timings));
        return new Relatedness(related, timings);
    }

    /** Wraps {@link #getEmbedding} so that the duration of the model call is recorded into {@code timings}. */
    private Function<Embedding, Boolean> timeEmbedding(List<Duration> timings) {
        return embedding -> {
            long start = System.nanoTime();
            boolean related = this.getEmbedding(embedding.t1, embedding.t2, embedding.threshold);
            if (timings.isEmpty()) {
                timings.add(Duration.ofNanos(System.nanoTime() - start));
            }
            return related;
        };
    }

    /**
     * Asks the daemon whether {@code s1} and {@code s2} are related at {@code threshold}. The result
     * is not cached.
     *
     * @throws IllegalStateException if the request is not successful or the response carries no body
     */
    public boolean getEmbedding(String s1, String s2, double threshold) {
        HttpRequestWithBody request = Unirest.post(PREDICT_URL);
        HttpResponse<GradioResponse> response = request
                .header("Content-Type", "application/json")
                .body(new GradioRequest(new Object[]{s1, s2, threshold}))
                .asObject(GradioResponse.class);
        if (!response.isSuccess()) {
            throw new IllegalStateException("Unable to get embedding. HTTP " + response.getStatus());
        }
        GradioResponse body = response.getBody();
        if (body == null) {
            throw new IllegalStateException("Unable to parse embedding response. HTTP " + response.getStatus());
        }
        return body.isRelated();
    }

    /** Cache key: the two texts and the threshold they were compared at. */
    public static final class Embedding {
        private final String t1;
        private final String t2;
        private final double threshold;

        public Embedding(String t1, String t2, double threshold) {
            this.t1 = t1;
            this.t2 = t2;
            this.threshold = threshold;
        }

        public String getT1() {
            return this.t1;
        }

        public String getT2() {
            return this.t2;
        }

        public double getThreshold() {
            return this.threshold;
        }

        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof Embedding)) {
                return false;
            }
            Embedding other = (Embedding) o;
            return Double.compare(this.threshold, other.threshold) == 0
                    && Objects.equals(this.t1, other.t1)
                    && Objects.equals(this.t2, other.t2);
        }

        @Override
        public int hashCode() {
            return Objects.hash(this.threshold, this.t1, this.t2);
        }

        @Override
        public String toString() {
            return "RelatedModelClient.Embedding(t1=" + this.getT1() + ", t2=" + this.getT2() + ", threshold=" + this.getThreshold() + ")";
        }
    }

    /** A relatedness verdict plus the timing of the model call that produced it (empty if cached). */
    public static final class Relatedness {
        private final boolean isRelated;
        private final List<Duration> embeddingTimings;

        public Relatedness(boolean isRelated, List<Duration> embeddingTimings) {
            this.isRelated = isRelated;
            this.embeddingTimings = embeddingTimings;
        }

        public boolean isRelated() {
            return this.isRelated;
        }

        public List<Duration> getEmbeddingTimings() {
            return this.embeddingTimings;
        }

        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof Relatedness)) {
                return false;
            }
            Relatedness other = (Relatedness) o;
            return this.isRelated == other.isRelated
                    && Objects.equals(this.embeddingTimings, other.embeddingTimings);
        }

        @Override
        public int hashCode() {
            return Objects.hash(this.isRelated, this.embeddingTimings);
        }

        @Override
        public String toString() {
            return "RelatedModelClient.Relatedness(isRelated=" + this.isRelated() + ", embeddingTimings=" + this.getEmbeddingTimings() + ")";
        }
    }

    /** JSON request body for {@code /run/predict}: positional inputs in a {@code data} array. */
    private static final class GradioRequest {
        private final Object[] data;

        public GradioRequest(Object[] data) {
            this.data = data;
        }

        public Object[] getData() {
            return this.data;
        }

        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof GradioRequest)) {
                return false;
            }
            return Arrays.deepEquals(this.data, ((GradioRequest) o).data);
        }

        @Override
        public int hashCode() {
            return 59 + Arrays.deepHashCode(this.data);
        }

        @Override
        public String toString() {
            return "RelatedModelClient.GradioRequest(data=" + Arrays.deepToString(this.getData()) + ")";
        }
    }

    /** JSON response body from {@code /run/predict}; the first {@code data} element holds the verdict. */
    private static final class GradioResponse {
        private final String[] data;

        /**
         * @return whether the first data element is exactly {@code "True"}
         * @throws IllegalStateException if the response carried no data elements
         */
        public boolean isRelated() {
            if (this.data == null || this.data.length == 0) {
                throw new IllegalStateException("Model daemon returned no prediction data");
            }
            return "True".equals(this.data[0]);
        }

        public GradioResponse(String[] data) {
            this.data = data;
        }

        public String[] getData() {
            return this.data;
        }

        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof GradioResponse)) {
                return false;
            }
            return Arrays.deepEquals(this.data, ((GradioResponse) o).data);
        }

        @Override
        public int hashCode() {
            return 59 + Arrays.deepHashCode(this.data);
        }

        @Override
        public String toString() {
            return "RelatedModelClient.GradioResponse(data=" + Arrays.deepToString(this.getData()) + ")";
        }
    }
}

