/*
 * Decompiled with CFR 0.152.
 */
package nak.model;

import java.io.IOException;
import java.util.Map;
import nak.maxent.GIS;
import nak.model.AbstractDataIndexer;
import nak.model.AbstractModel;
import nak.model.EventStream;
import nak.model.HashSumEventStream;
import nak.model.OnePassDataIndexer;
import nak.model.SequenceStream;
import nak.model.TwoPassDataIndexer;
import nak.perceptron.PerceptronTrainer;
import nak.perceptron.SimplePerceptronSequenceTrainer;
import nak.quasinewton.QNTrainer;

/**
 * Static entry points for validating a string-keyed training-parameter map and for training an
 * {@link AbstractModel} from it.
 *
 * <p>Event-based algorithms ({@value #MAXENT_VALUE}, {@value #MAXENT_QN_VALUE},
 * {@value #PERCEPTRON_VALUE}) are trained from an {@link EventStream}; the sequence algorithm
 * ({@value #PERCEPTRON_SEQUENCE_VALUE}) is trained from a {@link SequenceStream}.
 *
 * <p>Both {@code train} overloads accept an optional report map. When it is non-null, every
 * parameter the chosen training path reads is written into it with its effective value (the value
 * supplied in the parameter map, or the default used when the key is absent).
 */
public class TrainUtil {
    public static final String ALGORITHM_PARAM = "Algorithm";
    public static final String MAXENT_VALUE = "MAXENT";
    public static final String MAXENT_QN_VALUE = "MAXENT_QN_EXPERIMENTAL";
    public static final String PERCEPTRON_VALUE = "PERCEPTRON";
    public static final String PERCEPTRON_SEQUENCE_VALUE = "PERCEPTRON_SEQUENCE";
    public static final String CUTOFF_PARAM = "Cutoff";
    private static final int CUTOFF_DEFAULT = 5;
    public static final String ITERATIONS_PARAM = "Iterations";
    private static final int ITERATIONS_DEFAULT = 100;
    public static final String DATA_INDEXER_PARAM = "DataIndexer";
    public static final String DATA_INDEXER_ONE_PASS_VALUE = "OnePass";
    public static final String DATA_INDEXER_TWO_PASS_VALUE = "TwoPass";

    // Algorithm-specific keys and defaults. Unlike Cutoff and Iterations these are NOT checked by
    // isValid, so a malformed numeric value surfaces as a NumberFormatException from train.
    private static final String THREADS_PARAM = "Threads";
    private static final int THREADS_DEFAULT = 1;
    private static final String NUM_OF_UPDATES_PARAM = "numOfUpdates";
    private static final int NUM_OF_UPDATES_DEFAULT = 7;
    private static final String MAX_FCT_EVAL_PARAM = "maxFctEval";
    private static final int MAX_FCT_EVAL_DEFAULT = 300;
    private static final String USE_AVERAGE_PARAM = "UseAverage";
    private static final String USE_SKIPPED_AVERAGING_PARAM = "UseSkippedAveraging";
    private static final String STEP_SIZE_DECREASE_PARAM = "StepSizeDecrease";
    private static final String TOLERANCE_PARAM = "Tolerance";
    private static final double TOLERANCE_DEFAULT = 1.0E-5;

    /** Report-map key under which the hash sum of the consumed training events is recorded. */
    private static final String EVENT_HASH_REPORT_KEY = "Training-Eventhash";

    /** Writes {@code key -> value} into {@code reportMap} when a report map was supplied. */
    private static void report(Map<String, String> reportMap, String key, String value) {
        if (reportMap != null) {
            reportMap.put(key, value);
        }
    }

    /**
     * Returns the raw string for {@code key}, or {@code defaultValue} when absent, and reports it.
     */
    private static String getStringParam(Map<String, String> trainParams, String key,
            String defaultValue, Map<String, String> reportMap) {
        String value = trainParams.get(key);
        if (value == null) {
            value = defaultValue;
        }
        report(reportMap, key, value);
        return value;
    }

    /**
     * Returns {@code key} parsed as an int, or {@code defaultValue} when absent, and reports the
     * effective value.
     *
     * @throws NumberFormatException if the supplied value is not a valid int
     */
    private static int getIntParam(Map<String, String> trainParams, String key, int defaultValue,
            Map<String, String> reportMap) {
        String raw = trainParams.get(key);
        // Parse before reporting so a malformed value never leaves a partial report entry.
        int value = raw != null ? Integer.parseInt(raw) : defaultValue;
        report(reportMap, key, Integer.toString(value));
        return value;
    }

    /**
     * Returns {@code key} parsed as a double, or {@code defaultValue} when absent, and reports the
     * effective value.
     *
     * @throws NumberFormatException if the supplied value is not a valid double
     */
    private static double getDoubleParam(Map<String, String> trainParams, String key,
            double defaultValue, Map<String, String> reportMap) {
        String raw = trainParams.get(key);
        double value = raw != null ? Double.parseDouble(raw) : defaultValue;
        report(reportMap, key, Double.toString(value));
        return value;
    }

    /**
     * Returns {@code key} parsed with {@link Boolean#parseBoolean} (anything other than a
     * case-insensitive "true" is {@code false}), or {@code defaultValue} when absent, and reports
     * the effective value.
     */
    private static boolean getBooleanParam(Map<String, String> trainParams, String key,
            boolean defaultValue, Map<String, String> reportMap) {
        String raw = trainParams.get(key);
        boolean value = raw != null ? Boolean.parseBoolean(raw) : defaultValue;
        report(reportMap, key, Boolean.toString(value));
        return value;
    }

    /** Returns true when {@code value} is absent or parses as an int. */
    private static boolean isAbsentOrInt(String value) {
        if (value == null) {
            return true;
        }
        try {
            Integer.parseInt(value);
            return true;
        } catch (NumberFormatException ignored) {
            // Malformed number: the parameter map is invalid.
            return false;
        }
    }

    /**
     * Performs basic validation of a training-parameter map.
     *
     * <p>The map is valid when: Algorithm is absent or one of the four known algorithm values;
     * Cutoff and Iterations are absent or parse as ints; and DataIndexer is absent, OnePass or
     * TwoPass. No other keys are inspected.
     *
     * @param trainParams the training parameters; must not be {@code null}
     * @return whether the map passes the checks above
     */
    public static boolean isValid(Map<String, String> trainParams) {
        String algorithm = trainParams.get(ALGORITHM_PARAM);
        if (algorithm != null
                && !MAXENT_VALUE.equals(algorithm)
                && !MAXENT_QN_VALUE.equals(algorithm)
                && !PERCEPTRON_VALUE.equals(algorithm)
                && !PERCEPTRON_SEQUENCE_VALUE.equals(algorithm)) {
            return false;
        }
        if (!isAbsentOrInt(trainParams.get(CUTOFF_PARAM))
                || !isAbsentOrInt(trainParams.get(ITERATIONS_PARAM))) {
            return false;
        }
        String dataIndexer = trainParams.get(DATA_INDEXER_PARAM);
        return dataIndexer == null
                || DATA_INDEXER_ONE_PASS_VALUE.equals(dataIndexer)
                || DATA_INDEXER_TWO_PASS_VALUE.equals(dataIndexer);
    }

    /**
     * Trains an event-based model (MAXENT, MAXENT_QN_EXPERIMENTAL or PERCEPTRON; MAXENT when the
     * Algorithm key is absent).
     *
     * @param eventStream the training events; wrapped so their hash sum can be reported
     * @param trainParams the training parameters
     * @param reportMap optional map that receives every parameter used and, after training, the
     *     event hash sum; may be {@code null}
     * @return the trained model
     * @throws IllegalArgumentException if {@link #isValid} rejects {@code trainParams}, if the
     *     algorithm is the sequence algorithm, or (as {@link NumberFormatException}) if an
     *     algorithm-specific numeric parameter is malformed
     * @throws IOException propagated from data indexing or training
     */
    public static AbstractModel train(EventStream eventStream, Map<String, String> trainParams,
            Map<String, String> reportMap) throws IOException {
        if (!isValid(trainParams)) {
            throw new IllegalArgumentException("trainParams are not valid!");
        }
        if (isSequenceTraining(trainParams)) {
            throw new IllegalArgumentException("sequence training is not supported by this method!");
        }

        String algorithm = getStringParam(trainParams, ALGORITHM_PARAM, MAXENT_VALUE, reportMap);
        int iterations = getIntParam(trainParams, ITERATIONS_PARAM, ITERATIONS_DEFAULT, reportMap);
        int cutoff = getIntParam(trainParams, CUTOFF_PARAM, CUTOFF_DEFAULT, reportMap);

        // Third argument to the data-indexer constructors: true for the MAXENT variants, false for
        // PERCEPTRON. NOTE(review): presumably "sort/merge events" — confirm in the indexers.
        boolean sort;
        if (MAXENT_VALUE.equals(algorithm) || MAXENT_QN_VALUE.equals(algorithm)) {
            sort = true;
        } else if (PERCEPTRON_VALUE.equals(algorithm)) {
            sort = false;
        } else {
            throw new IllegalStateException("Unexpected algorithm name: " + algorithm);
        }

        HashSumEventStream hashedEvents = new HashSumEventStream(eventStream);
        String indexerName = getStringParam(trainParams, DATA_INDEXER_PARAM,
                DATA_INDEXER_TWO_PASS_VALUE, reportMap);
        AbstractDataIndexer indexer = createDataIndexer(indexerName, hashedEvents, cutoff, sort);

        AbstractModel model;
        if (MAXENT_VALUE.equals(algorithm)) {
            int threads = getIntParam(trainParams, THREADS_PARAM, THREADS_DEFAULT, reportMap);
            // NOTE(review): the positional arguments (true, false, null, 0) are defined by
            // GIS.trainModel's signature — confirm their meaning there before changing them.
            model = GIS.trainModel(iterations, indexer, true, false, null, 0, threads);
        } else if (MAXENT_QN_VALUE.equals(algorithm)) {
            int numOfUpdates = getIntParam(trainParams, NUM_OF_UPDATES_PARAM,
                    NUM_OF_UPDATES_DEFAULT, reportMap);
            int maxFctEval = getIntParam(trainParams, MAX_FCT_EVAL_PARAM, MAX_FCT_EVAL_DEFAULT,
                    reportMap);
            model = new QNTrainer(numOfUpdates, maxFctEval, true).trainModel(indexer);
        } else if (PERCEPTRON_VALUE.equals(algorithm)) {
            model = trainPerceptron(trainParams, indexer, iterations, cutoff, reportMap);
        } else {
            throw new IllegalStateException("Algorithm not supported: " + algorithm);
        }

        // The hash sum is only meaningful once training has consumed the wrapped stream.
        report(reportMap, EVENT_HASH_REPORT_KEY, hashedEvents.calculateHashSum().toString(16));
        return model;
    }

    /** Builds the data indexer selected by {@code indexerName} (OnePass or TwoPass). */
    private static AbstractDataIndexer createDataIndexer(String indexerName,
            HashSumEventStream events, int cutoff, boolean sort) throws IOException {
        if (DATA_INDEXER_ONE_PASS_VALUE.equals(indexerName)) {
            return new OnePassDataIndexer(events, cutoff, sort);
        }
        if (DATA_INDEXER_TWO_PASS_VALUE.equals(indexerName)) {
            return new TwoPassDataIndexer(events, cutoff, sort);
        }
        throw new IllegalStateException("Unexpected data indexer name: " + indexerName);
    }

    /** Configures a {@link PerceptronTrainer} from {@code trainParams} and trains it. */
    private static AbstractModel trainPerceptron(Map<String, String> trainParams,
            AbstractDataIndexer indexer, int iterations, int cutoff,
            Map<String, String> reportMap) throws IOException {
        boolean useAverage = getBooleanParam(trainParams, USE_AVERAGE_PARAM, true, reportMap);
        boolean useSkippedAveraging = getBooleanParam(trainParams, USE_SKIPPED_AVERAGING_PARAM,
                false, reportMap);
        if (useSkippedAveraging && !useAverage) {
            // Skipped averaging implies averaging; keep the report consistent with the value used.
            useAverage = true;
            report(reportMap, USE_AVERAGE_PARAM, Boolean.toString(true));
        }
        double stepSizeDecrease = getDoubleParam(trainParams, STEP_SIZE_DECREASE_PARAM, 0.0,
                reportMap);
        double tolerance = getDoubleParam(trainParams, TOLERANCE_PARAM, TOLERANCE_DEFAULT,
                reportMap);

        PerceptronTrainer trainer = new PerceptronTrainer();
        trainer.setSkippedAveraging(useSkippedAveraging);
        // Only a positive value is forwarded; otherwise the trainer keeps its own default.
        if (stepSizeDecrease > 0.0) {
            trainer.setStepSizeDecrease(stepSizeDecrease);
        }
        trainer.setTolerance(tolerance);
        return trainer.trainModel(iterations, indexer, cutoff, useAverage);
    }

    /**
     * Returns whether {@code trainParams} selects the sequence algorithm
     * ({@value #PERCEPTRON_SEQUENCE_VALUE}).
     *
     * @param trainParams the training parameters; must not be {@code null}
     */
    public static boolean isSequenceTraining(Map<String, String> trainParams) {
        return PERCEPTRON_SEQUENCE_VALUE.equals(trainParams.get(ALGORITHM_PARAM));
    }

    /**
     * Trains a sequence model with {@code SimplePerceptronSequenceTrainer}.
     *
     * @param sequenceStream the training sequences
     * @param trainParams the training parameters; Algorithm must be
     *     {@value #PERCEPTRON_SEQUENCE_VALUE}
     * @param reportMap optional map that receives every parameter used; may be {@code null}
     * @return the trained model
     * @throws IllegalArgumentException if {@link #isValid} rejects {@code trainParams} or the
     *     algorithm is not the sequence algorithm
     * @throws IOException propagated from training
     */
    public static AbstractModel train(SequenceStream sequenceStream,
            Map<String, String> trainParams, Map<String, String> reportMap) throws IOException {
        if (!isValid(trainParams)) {
            throw new IllegalArgumentException("trainParams are not valid!");
        }
        if (!isSequenceTraining(trainParams)) {
            throw new IllegalArgumentException("Algorithm must be a sequence algorithm!");
        }
        // Report the algorithm as the event-based overload does.
        report(reportMap, ALGORITHM_PARAM, PERCEPTRON_SEQUENCE_VALUE);
        int iterations = getIntParam(trainParams, ITERATIONS_PARAM, ITERATIONS_DEFAULT, reportMap);
        int cutoff = getIntParam(trainParams, CUTOFF_PARAM, CUTOFF_DEFAULT, reportMap);
        boolean useAverage = getBooleanParam(trainParams, USE_AVERAGE_PARAM, true, reportMap);
        return new SimplePerceptronSequenceTrainer()
                .trainModel(iterations, sequenceStream, cutoff, useAverage);
    }
}

