/*
 * Decompiled with CFR 0.152.
 */
package org.kie.kogito.explainability.utils;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import org.kie.kogito.explainability.model.Feature;
import org.kie.kogito.explainability.model.PerturbationContext;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.Type;
import org.kie.kogito.explainability.model.Value;

/**
 * Encodes the categorical features of {@link PredictionInput}s as numeric features and decodes them back.
 *
 * <p>Two encodings are available, selected by the {@code proxy} flag of the public methods:
 * <ul>
 * <li><b>one-hot</b> ({@code proxy == false}): a categorical feature is replaced by one NUMBER feature per
 * registered value, named {@code name + oheSplitter + index + oheSplitter + value}, holding 1 for the
 * feature's current value and 0 for every other value;</li>
 * <li><b>proxy</b> ({@code proxy == true}): a categorical feature is replaced by a single NUMBER feature
 * named {@code name + proxySplitter} holding the index of the feature's current value.</li>
 * </ul>
 * Indices are positions in the insertion-ordered set of values registered for the feature.
 *
 * <p>Instances are mutable: encoding an input that carries a not-yet-registered value for a registered
 * feature registers that value. The class performs no synchronization.
 */
public class OneHotter {
    /** Letters a splitter may be extended with when it clashes with a feature name. */
    private static final String ALPHABET = "abcdefghijklmnopqrstuvwxyz";

    /** Separator placed between parent name, value index and value in one-hot feature names. */
    private String oheSplitter = "_OHE_";
    /** Suffix appended to the parent name in proxy feature names. */
    private String proxySplitter = "_OHEPROXY";
    /** Source of the letters used to extend clashing splitters. */
    private final Random rn;
    /** Registered values per categorical feature name, in first-seen order. */
    private Map<String, LinkedHashSet<Value>> categoricals = new HashMap<>();

    /**
     * Registers every value of every {@code Type.CATEGORICAL} feature found in the given inputs.
     *
     * @param pis inputs whose categorical features define the initial set of registered values
     * @param pc perturbation context supplying the {@link Random} used to disambiguate splitters
     */
    public OneHotter(List<PredictionInput> pis, PerturbationContext pc) {
        this.rn = pc.getRandom();
        for (PredictionInput pi : pis) {
            for (Feature f : pi.getFeatures()) {
                if (f.getType() == Type.CATEGORICAL) {
                    addFeatureValue(f, true);
                }
            }
        }
    }

    /**
     * Encodes the registered categorical features of a single input.
     *
     * @param pi the input to encode
     * @param proxy {@code true} for proxy encoding, {@code false} for one-hot encoding
     * @return {@code pi} itself when no categorical feature is registered, otherwise a new input in which
     *         each registered categorical feature is replaced by its encoded feature(s) and every other
     *         feature is kept as is
     */
    public PredictionInput oneHotEncode(PredictionInput pi, boolean proxy) {
        if (categoricals.isEmpty()) {
            return pi;
        }
        List<Feature> encodedFeatures = new ArrayList<>();
        for (Feature f : pi.getFeatures()) {
            LinkedHashSet<Value> knownValues = categoricals.get(f.getName());
            if (knownValues == null || f.getType() != Type.CATEGORICAL) {
                encodedFeatures.add(f);
                continue;
            }
            if (!knownValues.contains(f.getValue())) {
                // Unseen value of a registered feature: register it so that it receives an index.
                addFeatureValue(f, false);
            }
            featureGenerator(f, knownValues.toArray(new Value[0]), proxy, encodedFeatures);
        }
        return new PredictionInput(encodedFeatures);
    }

    /**
     * Encodes every input of a list, see {@link #oneHotEncode(PredictionInput, boolean)}.
     *
     * @param pis the inputs to encode
     * @param proxy {@code true} for proxy encoding, {@code false} for one-hot encoding
     * @return {@code pis} itself when no categorical feature is registered, otherwise a new list holding
     *         the encoded inputs in the same order
     */
    public List<PredictionInput> oneHotEncode(List<PredictionInput> pis, boolean proxy) {
        if (categoricals.isEmpty()) {
            return pis;
        }
        List<PredictionInput> encodedPIs = new ArrayList<>(pis.size());
        for (PredictionInput pi : pis) {
            encodedPIs.add(oneHotEncode(pi, proxy));
        }
        return encodedPIs;
    }

    /**
     * Reverses {@link #oneHotEncode(PredictionInput, boolean)} for a single input.
     *
     * <p>A feature is treated as encoded only when its name contains the splitter of the selected encoding
     * and the text before the splitter is the name of a registered categorical feature; any other feature
     * is passed through unchanged. One-hot features whose value is not 1 produce no output feature.
     *
     * @param pi the input to decode
     * @param proxy {@code true} if {@code pi} is proxy encoded, {@code false} if it is one-hot encoded
     * @return {@code pi} itself when no categorical feature is registered, otherwise a new input with the
     *         encoded features replaced by categorical features
     * @throws NumberFormatException if a one-hot feature name of a registered feature carries a non-numeric index
     * @throws ArrayIndexOutOfBoundsException if a decoded index does not refer to a registered value
     */
    public PredictionInput oneHotDecode(PredictionInput pi, boolean proxy) {
        if (categoricals.isEmpty()) {
            return pi;
        }
        String splitter = proxy ? proxySplitter : oheSplitter;
        List<Feature> decodedFeatures = new ArrayList<>();
        for (Feature f : pi.getFeatures()) {
            String name = f.getName();
            int splitAt = name.indexOf(splitter);
            String parentFeature = splitAt < 0 ? null : name.substring(0, splitAt);
            LinkedHashSet<Value> knownValues = parentFeature == null ? null : categoricals.get(parentFeature);
            if (knownValues == null) {
                // Not an encoding of a registered categorical feature: keep the feature untouched.
                decodedFeatures.add(f);
                continue;
            }
            int valueIndex;
            if (proxy) {
                valueIndex = (int) f.getValue().asNumber();
            } else {
                if (f.getValue().asNumber() != 1.0) {
                    // Only the "hot" column identifies the original value.
                    continue;
                }
                valueIndex = parseOheIndex(name, splitAt + splitter.length(), splitter);
            }
            Value decoded = knownValues.toArray(new Value[0])[valueIndex];
            decodedFeatures.add(new Feature(parentFeature, Type.CATEGORICAL, decoded));
        }
        return new PredictionInput(decodedFeatures);
    }

    /**
     * Decodes every input of a list, see {@link #oneHotDecode(PredictionInput, boolean)}.
     *
     * @param pis the inputs to decode
     * @param proxy {@code true} if the inputs are proxy encoded, {@code false} if they are one-hot encoded
     * @return {@code pis} itself when no categorical feature is registered, otherwise a new list holding
     *         the decoded inputs in the same order
     */
    public List<PredictionInput> oneHotDecode(List<PredictionInput> pis, boolean proxy) {
        if (categoricals.isEmpty()) {
            return pis;
        }
        List<PredictionInput> decodedPIs = new ArrayList<>(pis.size());
        for (PredictionInput pi : pis) {
            decodedPIs.add(oneHotDecode(pi, proxy));
        }
        return decodedPIs;
    }

    /**
     * Registers the value of {@code f} under the feature's name, first making sure that neither splitter
     * occurs inside that name.
     *
     * @param f the feature whose value is registered
     * @param initialization {@code true} while building the initial value sets; when {@code false} the
     *        feature name must already be registered
     * @throws IllegalArgumentException if {@code initialization} is {@code false} and the name is unknown
     */
    private void addFeatureValue(Feature f, boolean initialization) {
        String name = f.getName();
        if (!initialization && !categoricals.containsKey(name)) {
            throw new IllegalArgumentException(String.format("Feature name %s was not present in initialized dataset", name));
        }
        // A single random letter can reproduce the clash, so keep extending until the name is free of
        // both splitters. Every round lengthens the splitters, hence the loop terminates.
        while (name.contains(oheSplitter) || name.contains(proxySplitter)) {
            String randString = String.valueOf(ALPHABET.charAt(rn.nextInt(ALPHABET.length())));
            oheSplitter = oheSplitter.substring(0, oheSplitter.length() - 1) + "_" + randString + "_";
            proxySplitter = proxySplitter + "_" + randString;
        }
        categoricals.computeIfAbsent(name, key -> new LinkedHashSet<>()).add(f.getValue());
    }

    /**
     * Appends the encoded form of {@code prototype} to {@code encodedFeatures}.
     *
     * @param prototype the categorical feature being encoded
     * @param comparedVals the registered values of the feature, in index order
     * @param proxy {@code true} to emit one index feature, {@code false} to emit one 0/1 feature per value
     * @param encodedFeatures output list receiving the generated features
     */
    private void featureGenerator(Feature prototype, Value[] comparedVals, boolean proxy, List<Feature> encodedFeatures) {
        for (int i = 0; i < comparedVals.length; ++i) {
            if (proxy) {
                if (comparedVals[i].equals(prototype.getValue())) {
                    encodedFeatures.add(new Feature(prototype.getName() + proxySplitter, Type.NUMBER, new Value(i)));
                    return;
                }
                continue;
            }
            String encodedName = prototype.getName() + oheSplitter + i + oheSplitter + comparedVals[i];
            encodedFeatures.add(new Feature(encodedName, Type.NUMBER, new Value(prototype.getValue().equals(comparedVals[i]) ? 1 : 0)));
        }
    }

    /**
     * Extracts the value index from a one-hot feature name.
     *
     * @param name the full one-hot feature name
     * @param from position of the first character after the first splitter
     * @param splitter the one-hot splitter in use
     * @return the integer found between the first and the second splitter (or up to the end of the name
     *         when there is no second splitter)
     */
    private static int parseOheIndex(String name, int from, String splitter) {
        int to = name.indexOf(splitter, from);
        String indexText = to < 0 ? name.substring(from) : name.substring(from, to);
        return Integer.parseInt(indexText);
    }
}

