/*
 * Decompiled with CFR 0.152.
 */
package sklearn.neural_network;

import com.google.common.collect.Iterables;
import java.util.ArrayList;
import java.util.List;
import org.dmg.pmml.DataType;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Entity;
import org.dmg.pmml.Expression;
import org.dmg.pmml.FieldRef;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.NormDiscrete;
import org.dmg.pmml.OpType;
import org.dmg.pmml.neural_network.Connection;
import org.dmg.pmml.neural_network.NeuralInput;
import org.dmg.pmml.neural_network.NeuralInputs;
import org.dmg.pmml.neural_network.NeuralLayer;
import org.dmg.pmml.neural_network.NeuralNetwork;
import org.dmg.pmml.neural_network.NeuralOutput;
import org.dmg.pmml.neural_network.NeuralOutputs;
import org.dmg.pmml.neural_network.Neuron;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.ContinuousLabel;
import org.jpmml.converter.Feature;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.ValueUtil;
import org.jpmml.sklearn.ClassDictUtil;
import org.jpmml.sklearn.HasArray;
import org.jpmml.sklearn.MatrixUtil;

/**
 * Static helpers that translate multi-layer perceptron parameters (one
 * coefficient matrix and one intercept vector per layer) into a PMML
 * {@link NeuralNetwork}.
 *
 * <p>Neuron identifiers follow the pattern {@code "<layer>/<index>"}, both
 * 1-based, with layer {@code 0} reserved for the neural inputs.
 */
public class NeuralNetworkUtil {
    private NeuralNetworkUtil() {
    }

    /**
     * Returns the row count of the first coefficient matrix. {@link #encodeNeuralNetwork}
     * connects row {@code i} of that matrix to the {@code i}-th neural input, so this is
     * the number of inputs the first layer consumes.
     *
     * @param coefs per-layer coefficient matrices; only the first element is inspected
     * @return the first dimension of the first coefficient matrix
     * @throws IllegalArgumentException if the first coefficient array is not 2-dimensional
     */
    public static int getNumberOfFeatures(List<? extends HasArray> coefs) {
        HasArray input = coefs.get(0);
        int[] shape = input.getArrayShape();
        checkMatrixShape(shape);
        return shape[0];
    }

    /**
     * Builds a PMML neural network from per-layer coefficients and intercepts.
     *
     * <p>Every layer uses the parsed {@code activation} function (set as the network
     * default), except the last one, which is forced to {@code IDENTITY}. For
     * classification the last layer is then post-processed:
     * <ul>
     *   <li>two categories: a logistic layer followed by a two-neuron layer emitting
     *       {@code 1 - p} and {@code p} is appended;</li>
     *   <li>more than two categories: softmax normalization is applied to the last layer.</li>
     * </ul>
     *
     * <p>Mining functions other than {@code REGRESSION} and {@code CLASSIFICATION} are not
     * rejected; the resulting network simply carries no neural outputs.
     *
     * @param miningFunction the mining function of the model
     * @param activation one of {@code "identity"}, {@code "logistic"}, {@code "relu"}, {@code "tanh"}
     * @param coefs per-layer coefficient matrices, each of shape (inputs, outputs)
     * @param intercepts per-layer intercept vectors, indexed by output neuron
     * @param schema supplies the features (one neural input each) and the label
     * @return the encoded neural network
     * @throws IllegalArgumentException if {@code activation} is not supported, if a coefficient
     *         array is not 2-dimensional, or if a classification label has fewer than two categories
     */
    public static NeuralNetwork encodeNeuralNetwork(MiningFunction miningFunction, String activation, List<? extends HasArray> coefs, List<? extends HasArray> intercepts, Schema schema) {
        NeuralNetwork.ActivationFunction activationFunction = NeuralNetworkUtil.parseActivationFunction(activation);
        // NOTE(review): presumably rejects lists of different length - confirm against ClassDictUtil.
        ClassDictUtil.checkSize(coefs, intercepts);

        NeuralInputs neuralInputs = NeuralNetworkUtil.encodeNeuralInputs(schema);

        // The entities feeding the layer under construction: the neural inputs at first,
        // then the neurons of the most recently built layer.
        List<? extends Entity> entities = neuralInputs.getNeuralInputs();
        List<NeuralLayer> neuralLayers = new ArrayList<>();
        int lastLayer = coefs.size() - 1;
        for (int layer = 0; layer <= lastLayer; ++layer) {
            HasArray coef = coefs.get(layer);
            HasArray intercept = intercepts.get(layer);
            int[] shape = coef.getArrayShape();
            // Fail with a clear message instead of an ArrayIndexOutOfBoundsException below.
            checkMatrixShape(shape);
            int rows = shape[0];
            int columns = shape[1];

            List<Neuron> neurons = NeuralNetworkUtil.createNeurons(layer, columns, intercept.getArrayContent());

            // Row i of the matrix holds the weights from input entity i to every neuron of this layer.
            List<?> coefMatrix = coef.getArrayContent();
            for (int row = 0; row < rows; ++row) {
                List<?> weights = MatrixUtil.getRow(coefMatrix, rows, columns, row);
                NeuralNetworkUtil.connect(entities.get(row), neurons, weights);
            }

            NeuralLayer neuralLayer = new NeuralLayer(neurons);
            if (layer == lastLayer) {
                neuralLayer.setActivationFunction(NeuralNetwork.ActivationFunction.IDENTITY);
                switch (miningFunction) {
                    case REGRESSION: {
                        break;
                    }
                    case CLASSIFICATION: {
                        CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();
                        int categories = categoricalLabel.size();
                        if (categories == 2) {
                            // Binary case: raw output -> logistic -> (1 - p, p).
                            // The first two layers are registered here, the third one after the switch.
                            neuralLayers.add(neuralLayer);
                            neuralLayer = NeuralNetworkUtil.encodeLogisticTransform(NeuralNetworkUtil.getOnlyNeuron(neuralLayer));
                            neuralLayers.add(neuralLayer);
                            neuralLayer = NeuralNetworkUtil.encodeLabelBinarizerTransform(NeuralNetworkUtil.getOnlyNeuron(neuralLayer));
                        } else if (categories > 2) {
                            neuralLayer.setNormalizationMethod(NeuralNetwork.NormalizationMethod.SOFTMAX);
                        } else {
                            throw new IllegalArgumentException("Expected a categorical label with two or more categories, got " + categories);
                        }
                        break;
                    }
                    default: {
                        break;
                    }
                }
            }
            entities = neuralLayer.getNeurons();
            neuralLayers.add(neuralLayer);
        }

        NeuralOutputs neuralOutputs = null;
        switch (miningFunction) {
            case REGRESSION: {
                neuralOutputs = NeuralNetworkUtil.encodeRegressionNeuralOutputs(entities, schema);
                break;
            }
            case CLASSIFICATION: {
                neuralOutputs = NeuralNetworkUtil.encodeClassificationNeuralOutputs(entities, schema);
                break;
            }
            default: {
                // Other mining functions keep neuralOutputs == null.
                break;
            }
        }
        return new NeuralNetwork(miningFunction, activationFunction, ModelUtil.createMiningSchema(schema), neuralInputs, neuralLayers).setNeuralOutputs(neuralOutputs);
    }

    /**
     * Creates one continuous, double-typed neural input per schema feature, with ids
     * {@code "0/1"}, {@code "0/2"}, ... in feature order.
     */
    private static NeuralInputs encodeNeuralInputs(Schema schema) {
        NeuralInputs neuralInputs = new NeuralInputs();
        List<? extends Feature> features = schema.getFeatures();
        for (int column = 0; column < features.size(); ++column) {
            ContinuousFeature continuousFeature = features.get(column).toContinuousFeature();
            DerivedField derivedField = new DerivedField(OpType.CONTINUOUS, DataType.DOUBLE).setExpression(continuousFeature.ref());
            NeuralInput neuralInput = new NeuralInput().setId("0/" + (column + 1)).setDerivedField(derivedField);
            neuralInputs.addNeuralInputs(neuralInput);
        }
        return neuralInputs;
    }

    /**
     * Creates the (still unconnected) neurons of one layer.
     *
     * @param layer 0-based layer index; ids use {@code layer + 1}
     * @param count number of neurons to create
     * @param interceptVector bias values, read at indices {@code 0..count-1}; a bias is only
     *        set on the neuron when it is non-zero
     */
    private static List<Neuron> createNeurons(int layer, int count, List<?> interceptVector) {
        List<Neuron> neurons = new ArrayList<>();
        for (int column = 0; column < count; ++column) {
            Neuron neuron = new Neuron().setId((layer + 1) + "/" + (column + 1));
            Double bias = ValueUtil.asDouble((Number) interceptVector.get(column));
            if (!ValueUtil.isZero(bias)) {
                neuron.setBias(bias);
            }
            neurons.add(neuron);
        }
        return neurons;
    }

    /** Builds a single-neuron layer that applies the logistic function to {@code input} (weight 1, bias 0). */
    private static NeuralLayer encodeLogisticTransform(Neuron input) {
        NeuralLayer neuralLayer = new NeuralLayer().setActivationFunction(NeuralNetwork.ActivationFunction.LOGISTIC);
        Neuron neuron = new Neuron().setId("logistic/1").setBias(0.0).addConnections(new Connection(input.getId(), 1.0));
        neuralLayer.addNeurons(neuron);
        return neuralLayer;
    }

    /**
     * Builds a two-neuron identity layer from a single input value {@code p}:
     * {@code "event/false"} emits {@code 1 - p} and {@code "event/true"} emits {@code p}.
     */
    private static NeuralLayer encodeLabelBinarizerTransform(Neuron input) {
        NeuralLayer neuralLayer = new NeuralLayer().setActivationFunction(NeuralNetwork.ActivationFunction.IDENTITY);
        Neuron noEventNeuron = new Neuron().setId("event/false").setBias(1.0).addConnections(new Connection(input.getId(), -1.0));
        Neuron eventNeuron = new Neuron().setId("event/true").setBias(0.0).addConnections(new Connection(input.getId(), 1.0));
        neuralLayer.addNeurons(noEventNeuron, eventNeuron);
        return neuralLayer;
    }

    /**
     * Maps the single remaining entity to the continuous label of the schema.
     * The entity list is size-checked against {@code 1} before it is unwrapped.
     */
    private static NeuralOutputs encodeRegressionNeuralOutputs(List<? extends Entity> entities, Schema schema) {
        ContinuousLabel continuousLabel = (ContinuousLabel) schema.getLabel();
        ClassDictUtil.checkSize(1, entities);
        Entity entity = Iterables.getOnlyElement(entities);
        DerivedField derivedField = new DerivedField(OpType.CONTINUOUS, DataType.DOUBLE).setExpression(new FieldRef(continuousLabel.getName()));
        NeuralOutput neuralOutput = new NeuralOutput().setOutputNeuron(entity.getId()).setDerivedField(derivedField);
        return new NeuralOutputs().addNeuralOutputs(neuralOutput);
    }

    /**
     * Maps the {@code i}-th entity to the {@code i}-th category of the categorical label.
     * The entity list is size-checked against the number of categories first.
     */
    private static NeuralOutputs encodeClassificationNeuralOutputs(List<? extends Entity> entities, Schema schema) {
        CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();
        ClassDictUtil.checkSize(categoricalLabel.size(), entities);
        NeuralOutputs neuralOutputs = new NeuralOutputs();
        for (int i = 0; i < categoricalLabel.size(); ++i) {
            Entity entity = entities.get(i);
            DerivedField derivedField = new DerivedField(OpType.CATEGORICAL, DataType.STRING).setExpression(new NormDiscrete(categoricalLabel.getName(), categoricalLabel.getValue(i)));
            NeuralOutput neuralOutput = new NeuralOutput().setOutputNeuron(entity.getId()).setDerivedField(derivedField);
            neuralOutputs.addNeuralOutputs(neuralOutput);
        }
        return neuralOutputs;
    }

    /**
     * Adds a connection from {@code input} to every neuron, using {@code weights.get(i)} for
     * {@code neurons.get(i)}. Both lists are size-checked against each other first.
     */
    private static void connect(Entity input, List<Neuron> neurons, List<?> weights) {
        ClassDictUtil.checkSize(neurons, weights);
        for (int i = 0; i < neurons.size(); ++i) {
            Neuron neuron = neurons.get(i);
            Double weight = ValueUtil.asDouble((Number) weights.get(i));
            neuron.addConnections(new Connection(input.getId(), weight.doubleValue()));
        }
    }

    /** Returns the sole neuron of {@code neuralLayer}; delegates the exactly-one check to Guava. */
    private static Neuron getOnlyNeuron(NeuralLayer neuralLayer) {
        List<Neuron> neurons = neuralLayer.getNeurons();
        return Iterables.getOnlyElement(neurons);
    }

    /**
     * Ensures that an array shape describes a matrix.
     *
     * @throws IllegalArgumentException if {@code shape} does not have exactly two dimensions
     */
    private static void checkMatrixShape(int[] shape) {
        if (shape.length != 2) {
            throw new IllegalArgumentException("Expected a 2-dimensional coefficient array, got " + shape.length + " dimension(s)");
        }
    }

    /**
     * Translates an activation function name into its PMML counterpart.
     *
     * @throws IllegalArgumentException if the name is not one of
     *         {@code "identity"}, {@code "logistic"}, {@code "relu"}, {@code "tanh"}
     */
    private static NeuralNetwork.ActivationFunction parseActivationFunction(String activation) {
        switch (activation) {
            case "identity": {
                return NeuralNetwork.ActivationFunction.IDENTITY;
            }
            case "logistic": {
                return NeuralNetwork.ActivationFunction.LOGISTIC;
            }
            case "relu": {
                return NeuralNetwork.ActivationFunction.RECTIFIER;
            }
            case "tanh": {
                return NeuralNetwork.ActivationFunction.TANH;
            }
            default: {
                throw new IllegalArgumentException(activation);
            }
        }
    }
}

