package com.github.romualdrousseau.any2json.classifier;

import com.github.romualdrousseau.any2json.HeaderTag;
import com.github.romualdrousseau.any2json.Model;
import com.github.romualdrousseau.any2json.util.Disk;
import com.github.romualdrousseau.any2json.util.TempFile;
import com.github.romualdrousseau.shuju.commons.PythonManager;
import com.github.romualdrousseau.shuju.json.JSON;
import com.github.romualdrousseau.shuju.json.JSONArray;
import com.github.romualdrousseau.shuju.preprocessing.Text;
import com.github.romualdrousseau.shuju.preprocessing.hasher.VocabularyHasher;
import com.github.romualdrousseau.shuju.preprocessing.tokenizer.NgramTokenizer;
import com.github.romualdrousseau.shuju.preprocessing.tokenizer.ShingleTokenizer;
import com.github.romualdrousseau.shuju.types.Tensor;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.URISyntaxException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import java.nio.file.attribute.FileAttribute;
import java.util.Arrays;
import java.util.Base64;
import java.util.List;
import java.util.Map;
import java.util.stream.Stream;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.SessionFunction;
import org.tensorflow.exceptions.TensorFlowException;
import org.tensorflow.types.TFloat32;

/* loaded from: input_file:com/github/romualdrousseau/any2json/classifier/NetTagClassifier.class */
/**
 * Tag classifier that delegates prediction to a TensorFlow SavedModel and
 * training to an external Python kernel ({@code kernels.tf}).
 *
 * <p>The feature vector fed to the network is the concatenation of three
 * fixed-size parts: {@link #IN_ENTITY_SIZE} entity features,
 * {@link #IN_NAME_SIZE} hashed tokens of the header name and
 * {@link #IN_CONTEXT_SIZE} hashed tokens of the surrounding headers.
 *
 * <p>The SavedModel lives either in a caller-supplied directory (first
 * constructor) or is unpacked on demand from the Base64-encoded zip stored
 * under the {@code "model"} key of the {@link Model} JSON (second
 * constructor). In the latter case the unpacked directory is owned by this
 * instance and removed by {@link #close()}.
 */
public class NetTagClassifier extends SimpleTagClassifier {
    /** Number of entity features at the start of the feature vector. */
    public static final int IN_ENTITY_SIZE = 10;
    /** Number of hashed name tokens following the entity features. */
    public static final int IN_NAME_SIZE = 10;
    /** Number of hashed context tokens at the end of the feature vector. */
    public static final int IN_CONTEXT_SIZE = 100;
    /** Length to which the categorical tag label is padded. */
    public static final int OUT_TAG_SIZE = 64;

    private final List<String> vocabulary;
    private final int ngrams;
    private final int wordMinSize;
    private final List<String> lexicon;
    private final Text.ITokenizer tokenizer;
    private final Text.IHasher hasher;
    /** True when {@link #modelPath} is a temporary directory owned by this instance. */
    private final boolean isModelTemp;
    private Model model;
    private Path modelPath;
    private SavedModelBundle tagClassifierModel;
    private SessionFunction tagClassifierFunc;

    /**
     * Creates a classifier from explicit settings and a caller-owned model directory.
     *
     * <p>No {@link Model} is attached; {@link #updateModel(Model)} has to be
     * called before any method that needs the entity or tag lists.
     *
     * @param list  vocabulary used by the hasher
     * @param i     n-gram size; {@code 0} selects the shingle tokenizer instead
     * @param i2    minimum word size handed to the shingle tokenizer
     * @param list2 lexicon handed to the shingle tokenizer
     * @param path  directory of the SavedModel; never deleted by this class
     */
    public NetTagClassifier(List<String> list, int i, int i2, List<String> list2, Path path) {
        super((Model) null);
        this.vocabulary = list;
        this.ngrams = i;
        this.wordMinSize = i2;
        this.lexicon = list2;
        this.tokenizer = createTokenizer(this.ngrams, this.lexicon, this.wordMinSize);
        this.hasher = new VocabularyHasher(this.vocabulary);
        this.modelPath = path;
        this.isModelTemp = false;
    }

    /**
     * Creates a classifier whose settings are read from the JSON of the given model
     * (keys {@code "vocabulary"}, {@code "ngrams"}, {@code "wordMinSize"} and
     * {@code "lexicon"}). The SavedModel itself is unpacked lazily.
     *
     * @param model model carrying the classifier settings
     */
    public NetTagClassifier(Model model) {
        super(model);
        this.model = model;
        final var json = model.toJSON();
        this.vocabulary = JSON.<String>streamOf(json.getArray("vocabulary")).toList();
        this.ngrams = json.getInt("ngrams");
        this.wordMinSize = json.getInt("wordMinSize");
        this.lexicon = JSON.<String>streamOf(json.getArray("lexicon")).toList();
        this.tokenizer = createTokenizer(this.ngrams, this.lexicon, this.wordMinSize);
        this.hasher = new VocabularyHasher(this.vocabulary);
        this.modelPath = null;
        this.isModelTemp = true;
    }

    /**
     * Releases the loaded SavedModel and, when the model directory is a
     * temporary one owned by this instance, deletes it.
     *
     * <p>Calling this method more than once is harmless.
     *
     * @throws Exception if the temporary model directory cannot be deleted
     */
    public void close() throws Exception {
        releaseClassifier();
        if (this.isModelTemp && this.modelPath != null) {
            Disk.deleteDir(this.modelPath);
            // Forget the deleted directory so a second close() does not try to
            // delete it again and no later call loads from a stale path.
            this.modelPath = null;
        }
    }

    /**
     * Attaches the given model and writes the classifier settings and the
     * Base64-encoded zipped SavedModel into its JSON.
     *
     * @param model model to attach and update
     */
    public void updateModel(Model model) {
        this.model = model;
        final var json = this.model.toJSON();
        json.setArray("vocabulary", JSON.arrayOf(this.vocabulary));
        json.setInt("ngrams", this.ngrams);
        json.setInt("wordMinSize", this.wordMinSize);
        json.setArray("lexicon", JSON.arrayOf(this.lexicon));
        json.setString("model", modelToJSONString(this.modelPath));
    }

    /**
     * Predicts the tag of a header.
     *
     * @param str   header name
     * @param list  entities associated with the header
     * @param list2 names of the surrounding headers
     * @return the predicted tag, or {@code HeaderTag.None.getValue()} when no
     *         SavedModel could be loaded
     */
    public String predict(String str, List<String> list, List<String> list2) {
        ensureClassifierLoaded();
        if (this.tagClassifierModel == null) {
            return HeaderTag.None.getValue();
        }

        final double[] features = createTrainingVector(str, list, list2).stream()
                .mapToDouble(Integer::doubleValue)
                .toArray();
        final int nameStart = IN_ENTITY_SIZE;
        final int contextStart = nameStart + IN_NAME_SIZE;
        final int contextEnd = contextStart + IN_CONTEXT_SIZE;

        final int tagIndex;
        // TensorFlow tensors hold native memory: close the inputs and the
        // outputs once the winning index has been extracted.
        try (TFloat32 entityInput = toModelInput(features, 0, nameStart);
                TFloat32 nameInput = toModelInput(features, nameStart, contextStart);
                TFloat32 contextInput = toModelInput(features, contextStart, contextEnd);
                var outputs = this.tagClassifierFunc.call(Map.of(
                        "entity_input", entityInput,
                        "name_input", nameInput,
                        "context_input", contextInput))) {
            tagIndex = (int) Tensor.of((TFloat32) outputs.get("tag_output").get()).argmax(1).item(0);
        }
        return (String) this.model.getTagList().get(tagIndex);
    }

    /**
     * Dumps the training and validation sets to a fresh temporary directory
     * and starts the Python training kernel on them.
     *
     * <p>Any loaded SavedModel is released first.
     *
     * @param list  training entries, written to {@code training.json}
     * @param list2 validation entries, written to {@code validation.json}
     * @return the process started by the Python manager
     * @throws IOException          if the temporary directory cannot be created
     * @throws InterruptedException declared for the Python manager
     * @throws URISyntaxException   declared for the Python manager
     */
    public Process fit(List<TrainingEntry> list, List<TrainingEntry> list2) throws IOException, InterruptedException, URISyntaxException {
        resolveModelPath();
        releaseClassifier();

        final String dimensions = String.format("%d,%d,%d,%d", IN_ENTITY_SIZE, IN_NAME_SIZE, IN_CONTEXT_SIZE, OUT_TAG_SIZE);
        final Path dataDir = Files.createTempDirectory("any2json").toAbsolutePath();
        JSON.saveArray(toJSONArray(list), dataDir.resolve("training.json"));
        JSON.saveArray(toJSONArray(list2), dataDir.resolve("validation.json"));

        // NOTE(review): when no model path could be resolved the trainer is
        // started with "-m null" - confirm whether that case should be rejected.
        return new PythonManager("kernels.tf").setEnviroment(Map.of("TF_CPP_MIN_VLOG_LEVEL", "3", "TF_CPP_MIN_LOG_LEVEL", "3")).run(new String[]{"-V " + this.vocabulary.size(), "-s " + dimensions, "-t " + dataDir, "-m " + this.modelPath});
    }

    /** @return the vocabulary used by the hasher */
    public List<String> getVocabulary() {
        return this.vocabulary;
    }

    /** @return the lexicon used by the shingle tokenizer */
    public List<String> getLexicon() {
        return this.lexicon;
    }

    /**
     * Builds a training entry from a header and its expected tag.
     *
     * @param str   header name
     * @param list  entities associated with the header
     * @param list2 names of the surrounding headers
     * @param str2  expected tag
     * @return entry pairing the feature vector with the categorical tag label
     *         padded to {@link #OUT_TAG_SIZE}
     */
    public TrainingEntry createTrainingEntry(String str, List<String> list, List<String> list2, String str2) {
        final List<Integer> features = createTrainingVector(str, list, list2);
        return new TrainingEntry(features, Text.pad_sequence(Text.to_categorical(str2, this.model.getTagList()), OUT_TAG_SIZE));
    }

    /** Selects the shingle tokenizer when {@code ngrams} is 0, the n-gram tokenizer otherwise. */
    private static Text.ITokenizer createTokenizer(int ngrams, List<String> lexicon, int wordMinSize) {
        if (ngrams == 0) {
            return new ShingleTokenizer(lexicon, wordMinSize);
        }
        return new NgramTokenizer(ngrams);
    }

    /** Wraps {@code values[from, to)} into a single-row float tensor. */
    private static TFloat32 toModelInput(double[] values, int from, int to) {
        return Tensor.of(Arrays.copyOfRange(values, from, to)).reshape(new int[]{1, -1}).toTFloat32();
    }

    /** Converts the vectors of the given entries into a JSON array of arrays. */
    private static JSONArray toJSONArray(List<TrainingEntry> entries) {
        final JSONArray rows = JSON.newArray();
        for (final TrainingEntry entry : entries) {
            rows.append(JSON.arrayOf(entry.getVector().toString()));
        }
        return rows;
    }

    /**
     * Builds the feature vector: entity features, hashed name tokens, then the
     * distinct, sorted hashed tokens of every context header other than the
     * name itself, each part padded and cut to its fixed size.
     */
    private List<Integer> createTrainingVector(String name, List<String> entities, List<String> context) {
        final List<Integer> entityPart = Text.pad_sequence(
                Text.to_categorical(entities, this.model.getEntityList()), IN_ENTITY_SIZE).subList(0, IN_ENTITY_SIZE);
        final List<Integer> namePart = Text.pad_sequence(
                Text.one_hot(name, this.model.getFilters(), this.tokenizer, this.hasher), IN_NAME_SIZE).subList(0, IN_NAME_SIZE);
        final List<Integer> contextHashes = context.stream()
                .filter(other -> !other.equals(name))
                .flatMap(other -> Text.one_hot(other, this.model.getFilters(), this.tokenizer, this.hasher).stream())
                .distinct()
                .sorted()
                .toList();
        final List<Integer> contextPart = Text.pad_sequence(contextHashes, IN_CONTEXT_SIZE).subList(0, IN_CONTEXT_SIZE);
        return Stream.of(entityPart, namePart, contextPart).flatMap(List::stream).toList();
    }

    /** Unpacks the SavedModel stored in the attached model when no directory is known yet. */
    private void resolveModelPath() {
        if (this.modelPath == null && this.model != null) {
            this.modelPath = JSONStringToModelPath(this.model.toJSON().getString("model"));
        }
    }

    /** Closes the loaded SavedModel, if any, and drops the serving function. */
    private void releaseClassifier() {
        if (this.tagClassifierModel != null) {
            this.tagClassifierModel.close();
            this.tagClassifierModel = null;
            this.tagClassifierFunc = null;
        }
    }

    /**
     * Loads the SavedModel on first use. Loading is best effort: when no model
     * directory is available or TensorFlow rejects it, the classifier simply
     * stays unloaded and {@link #predict} falls back to its default tag.
     */
    private void ensureClassifierLoaded() {
        resolveModelPath();
        if (this.tagClassifierModel != null || this.modelPath == null) {
            return;
        }
        try {
            this.tagClassifierModel = SavedModelBundle.load(this.modelPath.toString(), new String[]{"serve"});
            this.tagClassifierFunc = this.tagClassifierModel.function("serving_default");
        } catch (TensorFlowException ignored) {
            // Deliberately swallowed: an unusable model means "no prediction".
            releaseClassifier();
        }
    }

    /**
     * Zips the model directory and returns the archive Base64-encoded.
     *
     * @return the encoded archive, or an empty string when {@code path} is
     *         {@code null} or does not exist
     * @throws UncheckedIOException if zipping or reading the archive fails
     */
    private String modelToJSONString(Path path) {
        if (path == null || !path.toFile().exists()) {
            return "";
        }
        // try-with-resources so the temporary archive is released on failure too.
        try (TempFile archive = new TempFile("model-", ".zip")) {
            Disk.zipDir(path, archive.getPath().toFile());
            return Base64.getEncoder().encodeToString(Files.readAllBytes(archive.getPath()));
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    /**
     * Decodes a Base64-encoded zip archive and unpacks it into a new temporary directory.
     *
     * @return the directory holding the unpacked model, or {@code null} when
     *         {@code str} is {@code null} or empty (the value
     *         {@link #modelToJSONString} writes when there is no model)
     * @throws UncheckedIOException if writing or unpacking the archive fails
     */
    private Path JSONStringToModelPath(String str) {
        if (str == null || str.isEmpty()) {
            return null;
        }
        // try-with-resources so the temporary archive is released on failure too.
        try (TempFile archive = new TempFile("model-", ".zip")) {
            Files.write(archive.getPath(), Base64.getDecoder().decode(str), StandardOpenOption.CREATE);
            final Path modelDir = Files.createTempDirectory("model-");
            Disk.unzipDir(archive.getPath(), modelDir);
            // Only effective if the directory is empty at JVM exit; close()
            // performs the real cleanup.
            modelDir.toFile().deleteOnExit();
            return modelDir;
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }
}
