package com.github.romualdrousseau.any2json.classifier;

import com.github.romualdrousseau.any2json.Model;
import com.github.romualdrousseau.any2json.TagClassifier;
import com.github.romualdrousseau.any2json.util.Disk;
import com.github.romualdrousseau.shuju.commons.PythonManager;
import com.github.romualdrousseau.shuju.json.JSON;
import com.github.romualdrousseau.shuju.json.JSONArray;
import com.github.romualdrousseau.shuju.preprocessing.Text;
import com.github.romualdrousseau.shuju.preprocessing.hasher.VocabularyHasher;
import com.github.romualdrousseau.shuju.preprocessing.tokenizer.NgramTokenizer;
import com.github.romualdrousseau.shuju.preprocessing.tokenizer.ShingleTokenizer;
import com.github.romualdrousseau.shuju.types.Tensor;
import java.io.IOException;
import java.net.URISyntaxException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import java.nio.file.attribute.FileAttribute;
import java.util.AbstractMap;
import java.util.Base64;
import java.util.List;
import java.util.Map;
import java.util.stream.Stream;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.SessionFunction;
import org.tensorflow.types.TFloat32;

/* loaded from: input_file:com/github/romualdrousseau/any2json/classifier/NetTagClassifier.class */
/**
 * {@link TagClassifier} that delegates tag prediction to a TensorFlow SavedModel.
 *
 * <p>Each sample is encoded as one flat list of integers made of three fixed-size
 * segments, in this order: entities ({@value #IN_ENTITY_SIZE} values), name
 * ({@value #IN_NAME_SIZE} values) and context ({@value #IN_CONTEXT_SIZE} values).
 * They are fed to the model's {@code entity_input}, {@code name_input} and
 * {@code context_input} inputs; the predicted tag is the arg-max of {@code tag_output}.
 *
 * <p>When no SavedModel exists at the model path, no bundle is loaded and
 * {@link #predict(String, List, List)} returns the first tag of the model's tag list.
 */
public class NetTagClassifier implements TagClassifier {
    private static final int IN_ENTITY_SIZE = 10;
    private static final int IN_NAME_SIZE = 10;
    private static final int IN_CONTEXT_SIZE = 100;
    private static final int OUT_TAG_SIZE = 64;

    /** Offsets of the name and context segments inside a flat input vector. */
    private static final int NAME_OFFSET = IN_ENTITY_SIZE;
    private static final int CONTEXT_OFFSET = NAME_OFFSET + IN_NAME_SIZE;
    private static final int INPUT_SIZE = CONTEXT_OFFSET + IN_CONTEXT_SIZE;

    private final Model model;
    private final List<String> vocabulary;
    private final int ngrams;
    private final int wordMinSize;
    private final List<String> lexicon;
    private final Path modelPath;
    private final Text.ITokenizer tokenizer;
    private final Text.IHasher hasher;
    private final SavedModelBundle tagClassifierModel;
    private final SessionFunction tagClassifierFunc;
    private final boolean modelIsTemp;

    /**
     * Creates a classifier from explicit settings and records those settings in the
     * JSON representation of {@code model}.
     *
     * @param model       model providing the entity list, tag list and filters
     * @param vocabulary  vocabulary given to the {@link VocabularyHasher}
     * @param ngrams      n-gram size; {@code 0} selects a {@link ShingleTokenizer} instead
     * @param wordMinSize minimum word size passed to the {@link ShingleTokenizer}
     * @param lexicon     lexicon passed to the {@link ShingleTokenizer}
     * @param modelPath   SavedModel directory; it is loaded only if it exists and is
     *                    never deleted by {@link #close()}
     */
    public NetTagClassifier(Model model, List<String> vocabulary, int ngrams, int wordMinSize, List<String> lexicon, Path modelPath) {
        this.model = model;
        this.vocabulary = vocabulary;
        this.ngrams = ngrams;
        this.wordMinSize = wordMinSize;
        this.lexicon = lexicon;
        this.tokenizer = newTokenizer(ngrams, wordMinSize, lexicon);
        this.hasher = new VocabularyHasher(vocabulary);
        this.modelPath = modelPath;
        this.tagClassifierModel = loadSavedModel(modelPath);
        this.tagClassifierFunc = servingFunction(this.tagClassifierModel);
        this.modelIsTemp = false;

        this.model.toJSON().setArray("vocabulary", JSON.arrayOf(this.vocabulary));
        // The JSON-based constructor reads "ngrams"; the historical key "ngram" is still
        // written so that anything relying on the old spelling keeps working.
        this.model.toJSON().setInt("ngrams", this.ngrams);
        this.model.toJSON().setInt("ngram", this.ngrams);
        this.model.toJSON().setInt("wordMinSize", this.wordMinSize);
        this.model.toJSON().setArray("lexicon", JSON.arrayOf(this.lexicon));
        // NOTE(review): the directory is zipped even when it does not exist yet; the
        // outcome then depends on Disk.zipDir - confirm this is intended.
        this.model.toJSON().setString("model", modelToJSONString(this.modelPath));
    }

    /**
     * Restores a classifier from the JSON representation of {@code model}.
     *
     * <p>The embedded SavedModel (Base64-encoded zip stored under the {@code "model"}
     * key) is unpacked into a temporary directory that {@link #close()} deletes.
     *
     * @param model model whose JSON holds {@code vocabulary}, {@code ngrams},
     *              {@code wordMinSize}, {@code lexicon} and {@code model}
     */
    public NetTagClassifier(Model model) {
        this.model = model;
        this.vocabulary = JSON.streamOf(model.toJSON().getArray("vocabulary")).toList();
        this.ngrams = model.toJSON().getInt("ngrams");
        this.wordMinSize = model.toJSON().getInt("wordMinSize");
        this.lexicon = JSON.streamOf(model.toJSON().getArray("lexicon")).toList();
        this.tokenizer = newTokenizer(this.ngrams, this.wordMinSize, this.lexicon);
        this.hasher = new VocabularyHasher(this.vocabulary);
        this.modelPath = jsonStringToModel(model.toJSON().getString("model"));
        this.tagClassifierModel = loadSavedModel(this.modelPath);
        this.tagClassifierFunc = servingFunction(this.tagClassifierModel);
        this.modelIsTemp = true;
    }

    /**
     * Releases the loaded SavedModel, then removes the model directory when it is a
     * temporary one created by {@link #NetTagClassifier(Model)}.
     *
     * @throws Exception if closing the bundle or deleting the directory fails
     */
    public void close() throws Exception {
        try {
            if (this.tagClassifierModel != null) {
                this.tagClassifierModel.close();
            }
        } finally {
            // Runs even if closing the bundle failed, so the temporary files are not left behind.
            if (this.modelIsTemp) {
                Disk.deleteDir(this.modelPath);
            }
        }
    }

    /**
     * Predicts the tag of a name.
     *
     * @param name     text encoded into the name segment
     * @param entities entities encoded against the model's entity list
     * @param context  surrounding texts; entries equal to {@code name} are ignored
     * @return the predicted tag, or the first tag of the model when no SavedModel is loaded
     */
    public String predict(String name, List<String> entities, List<String> context) {
        return predict(buildPredictEntry(name, entities, context));
    }

    /**
     * Builds one training sample.
     *
     * @param name     text encoded into the name segment
     * @param entities entities encoded against the model's entity list
     * @param context  surrounding texts; entries equal to {@code name} are ignored
     * @param tag      expected tag, encoded against the model's tag list
     * @return an entry whose key is the encoded tag padded to {@value #OUT_TAG_SIZE}
     *         values and whose value is the flat input vector
     */
    public AbstractMap.SimpleImmutableEntry<List<Integer>, List<Integer>> buildTrainingEntry(String name, List<String> entities, List<String> context, String tag) {
        final List<Integer> encodedTag = Text.pad_sequence(Text.to_categorical(tag, this.model.getTagList()), OUT_TAG_SIZE);
        return new AbstractMap.SimpleImmutableEntry<>(encodedTag, buildPredictEntry(name, entities, context));
    }

    /**
     * Writes the data sets to a fresh temporary directory and launches the
     * {@code kernels.tf} Python program on them.
     *
     * <p>The temporary directory is not deleted here because the returned process may
     * still be reading it.
     *
     * @param trainingSet   rows saved as {@code training.json}
     * @param validationSet rows saved as {@code validation.json}
     * @return the process started by {@link PythonManager}
     * @throws IOException          if the working files cannot be written
     * @throws InterruptedException declared for the Python launcher
     * @throws URISyntaxException   declared for the Python launcher
     */
    public Process fit(List<List<Integer>> trainingSet, List<List<Integer>> validationSet) throws IOException, InterruptedException, URISyntaxException {
        // Plain concatenation instead of String.format("%d", ...): the value is parsed by
        // another program and must not depend on the default locale.
        final String dimensions = IN_ENTITY_SIZE + "," + IN_NAME_SIZE + "," + IN_CONTEXT_SIZE + "," + OUT_TAG_SIZE;
        final Path workDir = Files.createTempDirectory("any2json").toAbsolutePath();

        JSON.saveArray(toJSONArray(trainingSet), workDir.resolve("training.json"));
        JSON.saveArray(toJSONArray(validationSet), workDir.resolve("validation.json"));

        return new PythonManager("kernels.tf")
                .setEnviroment(Map.of("TF_CPP_MIN_VLOG_LEVEL", "3", "TF_CPP_MIN_LOG_LEVEL", "3"))
                .run(new String[]{
                    "-V " + this.vocabulary.size(),
                    "-s " + dimensions,
                    "-t " + workDir,
                    "-m " + this.modelPath});
    }

    /** Chooses the tokenizer: shingles when {@code ngrams} is 0, n-grams otherwise. */
    private static Text.ITokenizer newTokenizer(int ngrams, int wordMinSize, List<String> lexicon) {
        if (ngrams == 0) {
            return new ShingleTokenizer(lexicon, wordMinSize);
        }
        return new NgramTokenizer(ngrams);
    }

    /** Loads the SavedModel with the "serve" tag, or returns {@code null} if the path does not exist. */
    private static SavedModelBundle loadSavedModel(Path path) {
        if (!path.toFile().exists()) {
            return null;
        }
        return SavedModelBundle.load(path.toString(), "serve");
    }

    /** Returns the "serving_default" function of the bundle, or {@code null} without a bundle. */
    private static SessionFunction servingFunction(SavedModelBundle bundle) {
        return bundle == null ? null : bundle.function("serving_default");
    }

    /** Converts rows of integers into a JSON array of JSON arrays. */
    private static JSONArray toJSONArray(List<List<Integer>> rows) {
        final JSONArray array = JSON.newArray();
        rows.forEach(row -> array.append(JSON.arrayOf(row.toString())));
        return array;
    }

    /** Wraps integer values into a single-row float tensor; the caller must close it. */
    private static TFloat32 toInputTensor(List<Integer> values) {
        return Tensor.of(values.stream().mapToDouble(Integer::doubleValue).toArray())
                .reshape(new int[]{1, -1})
                .toTFloat32();
    }

    /**
     * Encodes a sample as the concatenation of its entity, name and context segments,
     * each padded and then cut to its fixed size.
     */
    private List<Integer> buildPredictEntry(String name, List<String> entities, List<String> context) {
        final List<Integer> entityPart = Text
                .pad_sequence(Text.to_categorical(entities, this.model.getEntityList()), IN_ENTITY_SIZE)
                .subList(0, IN_ENTITY_SIZE);

        final List<Integer> namePart = Text
                .pad_sequence(Text.one_hot(name, this.model.getFilters(), this.tokenizer, this.hasher), IN_NAME_SIZE)
                .subList(0, IN_NAME_SIZE);

        // The context is the sorted set of hashes of every other text.
        final List<Integer> contextHashes = context.stream()
                .filter(other -> !other.equals(name))
                .flatMap(other -> Text.one_hot(other, this.model.getFilters(), this.tokenizer, this.hasher).stream())
                .distinct()
                .sorted()
                .toList();
        final List<Integer> contextPart = Text
                .pad_sequence(contextHashes, IN_CONTEXT_SIZE)
                .subList(0, IN_CONTEXT_SIZE);

        return Stream.of(entityPart, namePart, contextPart)
                .flatMap(List::stream)
                .toList();
    }

    /**
     * Runs the SavedModel on a flat input vector and maps the arg-max of
     * {@code tag_output} to a tag. Falls back to the first tag without a loaded model.
     */
    private String predict(List<Integer> input) {
        if (this.tagClassifierFunc == null) {
            return (String) this.model.getTagList().get(0);
        }
        // Tensors are backed by native memory: inputs and outputs are closed explicitly,
        // and the output is read before the result is released.
        try (TFloat32 entityInput = toInputTensor(input.subList(0, NAME_OFFSET));
                TFloat32 nameInput = toInputTensor(input.subList(NAME_OFFSET, CONTEXT_OFFSET));
                TFloat32 contextInput = toInputTensor(input.subList(CONTEXT_OFFSET, INPUT_SIZE));
                var result = this.tagClassifierFunc.call(Map.of(
                        "entity_input", entityInput,
                        "name_input", nameInput,
                        "context_input", contextInput))) {
            final int tagIndex = (int) Tensor.of((TFloat32) result.get("tag_output").get()).argmax(1).item(0);
            return (String) this.model.getTagList().get(tagIndex);
        }
    }

    /**
     * Zips the model directory and returns the archive as a Base64 string.
     *
     * @throws RuntimeException wrapping any {@link IOException}
     */
    private String modelToJSONString(Path path) {
        Path archive = null;
        try {
            archive = Files.createTempFile("model-", ".zip");
            Disk.zipDir(path, archive.toFile());
            return Base64.getEncoder().encodeToString(Files.readAllBytes(archive));
        } catch (IOException e) {
            throw new RuntimeException(e);
        } finally {
            deleteQuietly(archive);
        }
    }

    /**
     * Decodes a Base64 zip archive and unpacks it into a new temporary directory.
     *
     * @return the directory holding the unpacked model
     * @throws RuntimeException wrapping any {@link IOException}
     */
    private Path jsonStringToModel(String str) {
        Path archive = null;
        try {
            archive = Files.createTempFile("model-", ".zip");
            final Path modelDir = Files.createTempDirectory("model-");
            Files.write(archive, Base64.getDecoder().decode(str), StandardOpenOption.CREATE);
            Disk.unzipDir(archive, modelDir);
            return modelDir;
        } catch (IOException e) {
            throw new RuntimeException(e);
        } finally {
            deleteQuietly(archive);
        }
    }

    /** Best-effort removal of an intermediate file; {@code null} is accepted. */
    private static void deleteQuietly(Path file) {
        if (file == null) {
            return;
        }
        try {
            Files.deleteIfExists(file);
        } catch (IOException ignored) {
            // Cleanup only: failing to remove a temporary archive must not mask the real outcome.
        }
    }
}
