/*
 * Decompiled with CFR 0.152.
 */
package org.encog.ml.model.config;

import org.encog.EncogError;
import org.encog.engine.network.activation.ActivationFunction;
import org.encog.ml.data.versatile.VersatileMLDataSet;
import org.encog.ml.data.versatile.normalizers.strategies.BasicNormalizationStrategy;
import org.encog.ml.data.versatile.normalizers.strategies.NormalizationStrategy;
import org.encog.ml.factory.MLMethodFactory;
import org.encog.ml.model.config.MethodConfig;
import org.encog.neural.networks.BasicNetwork;

/**
 * {@link MethodConfig} for the {@code "feedforward"} method type, backed by {@link BasicNetwork}.
 *
 * <p>Supplies suggested defaults for this method type: a model architecture, a normalization
 * strategy, a training type with its arguments, and the output count.
 */
public class FeedforwardConfig implements MethodConfig {

    /**
     * Returns the method type name this configuration describes.
     *
     * @return always {@code "feedforward"}
     */
    @Override
    public String getMethodName() {
        return "feedforward";
    }

    /**
     * Suggests a single-hidden-layer architecture for the given dataset.
     *
     * <p>The hidden layer size is 1.5 times the combined number of input and output columns,
     * truncated toward zero. The architecture places TANH between each pair of layers and
     * {@code :B} markers (presumably bias neurons in Encog's architecture syntax) on the input
     * and hidden layers. The {@code ?} placeholders are left for the method factory to resolve,
     * presumably to the actual input/output counts — confirm against {@link MLMethodFactory}.
     *
     * @param dataset dataset whose normalization helper supplies the input/output column counts
     * @return an architecture string, e.g. {@code "?:B->TANH->6:B->TANH->?"} for 2 inputs and
     *     2 outputs
     */
    @Override
    public String suggestModelArchitecture(VersatileMLDataSet dataset) {
        int inputColumns = dataset.getNormHelper().getInputColumns().size();
        int outputColumns = dataset.getNormHelper().getOutputColumns().size();
        int hiddenCount = (int) ((double) (inputColumns + outputColumns) * 1.5);
        return "?:B->TANH->" + hiddenCount + ":B->TANH->?";
    }

    /**
     * Suggests normalization ranges that suit the activation functions of {@code architecture}.
     *
     * <p>A throwaway network is built from the architecture, passing 1 and 1 as the input and
     * output counts. It is used only so its activation functions can be inspected. Both ranges
     * default to {@code [-1, 1]}. The low end of a range is raised to {@code 0} when the probed
     * activation returns strictly positive values for strongly negative inputs (e.g. a
     * sigmoid-style function).
     * <ul>
     *   <li>The output range follows the activation of the last layer.</li>
     *   <li>The input range follows the activation of layer 1.</li>
     * </ul>
     *
     * @param dataset not consulted by this implementation
     * @param architecture architecture string understood by {@link MLMethodFactory}
     * @return a {@link BasicNormalizationStrategy} with the chosen input/output ranges
     * @throws EncogError if the constructed network has no output layer, or has no layer 1 to
     *     inspect
     */
    @Override
    public NormalizationStrategy suggestNormalizationStrategy(VersatileMLDataSet dataset, String architecture) {
        MLMethodFactory methodFactory = new MLMethodFactory();
        BasicNetwork network = (BasicNetwork) methodFactory.create(this.getMethodName(), architecture, 1, 1);
        int layerCount = network.getLayerCount();
        if (layerCount < 1) {
            throw new EncogError("Neural network does not have an output layer.");
        }
        // Layer 1 is inspected below, so at least two layers are required; without this check
        // a single-layer network would be asked for a layer it does not have.
        if (layerCount < 2) {
            throw new EncogError("Neural network does not have a layer after the input layer.");
        }

        double inputLow = -1.0;
        double inputHigh = 1.0;
        double outputLow = -1.0;
        double outputHigh = 1.0;

        // The output activation is probed before layer 1, matching the original evaluation order.
        if (producesPositiveOutputForNegativeInput(network.getActivation(layerCount - 1))) {
            outputLow = 0.0;
        }
        if (producesPositiveOutputForNegativeInput(network.getActivation(1))) {
            inputLow = 0.0;
        }
        return new BasicNormalizationStrategy(inputLow, inputHigh, outputLow, outputHigh);
    }

    /**
     * Suggests the training type for feedforward networks.
     *
     * @return always {@code "rprop"}
     */
    @Override
    public String suggestTrainingType() {
        return "rprop";
    }

    /**
     * Suggests training arguments for the given training type.
     *
     * @param trainingType ignored by this implementation
     * @return always the empty string
     */
    @Override
    public String suggestTrainingArgs(String trainingType) {
        return "";
    }

    /**
     * Determines how many outputs the model needs for the dataset.
     *
     * @param dataset dataset whose normalization helper computes the count
     * @return the value of {@code dataset.getNormHelper().calculateNormalizedOutputCount()}
     */
    @Override
    public int determineOutputCount(VersatileMLDataSet dataset) {
        return dataset.getNormHelper().calculateNormalizedOutputCount();
    }

    /**
     * Checks whether {@code function} maps strongly negative inputs to strictly positive
     * outputs. A {@code true} result is taken as a sign that its output range starts at zero.
     *
     * @param function activation function to probe; it is applied in place to a fresh scratch
     *     array, so no shared state is mutated here
     * @return {@code true} only if every probed output is greater than zero; NaN outputs count
     *     as not positive
     */
    private static boolean producesPositiveOutputForNegativeInput(ActivationFunction function) {
        double[] probe = new double[]{-1000.0, -100.0, -50.0};
        function.activationFunction(probe, 0, probe.length);
        for (double value : probe) {
            // Written as !(value > 0) rather than value <= 0 so that NaN is rejected.
            if (!(value > 0.0)) {
                return false;
            }
        }
        return true;
    }
}

