/*
 * Decompiled with CFR 0.152.
 */
package sklearn.neural_network;

import com.google.common.collect.Iterables;
import java.util.ArrayList;
import java.util.List;
import org.dmg.pmml.DataType;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.neural_network.Connection;
import org.dmg.pmml.neural_network.NeuralInputs;
import org.dmg.pmml.neural_network.NeuralLayer;
import org.dmg.pmml.neural_network.NeuralNetwork;
import org.dmg.pmml.neural_network.NeuralOutputs;
import org.dmg.pmml.neural_network.Neuron;
import org.jpmml.converter.CMatrixUtil;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousLabel;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.ValueUtil;
import org.jpmml.converter.neural_network.NeuralNetworkUtil;
import org.jpmml.sklearn.ClassDictUtil;
import org.jpmml.sklearn.HasArray;

/**
 * Static helpers for encoding a multilayer perceptron (given as per-layer coefficient matrices
 * and intercept vectors) as a PMML {@link NeuralNetwork}.
 */
public class MultilayerPerceptronUtil {

    private MultilayerPerceptronUtil() {
    }

    /**
     * Returns the number of inputs expected by the first layer of the network.
     *
     * @param coefs per-layer coefficient arrays; the first one must be two-dimensional
     * @return the row count of the first coefficient array
     * @throws IllegalArgumentException if {@code coefs} is empty or its first element is not 2-D
     */
    public static int getNumberOfFeatures(List<? extends HasArray> coefs) {
        if (coefs.isEmpty()) {
            throw new IllegalArgumentException("Expected at least one coefficient matrix, got none");
        }

        int[] shape = getMatrixShape(coefs.get(0), 0);

        return shape[0];
    }

    /**
     * Encodes a feed-forward network as a PMML {@code NeuralNetwork}.
     *
     * <p>Each coefficient array is read as a {@code rows x columns} matrix where column {@code j}
     * holds the incoming weights of neuron {@code j} of that layer, and {@code intercepts[layer][j]}
     * is its bias. Hidden layers use the network-level activation function; the last layer uses
     * identity activation, followed by:
     * <ul>
     *   <li>nothing extra for regression;</li>
     *   <li>a logistic layer plus a two-neuron {@code (1 - p, p)} layer for binary classification;</li>
     *   <li>softmax normalization for classification with more than two categories.</li>
     * </ul>
     *
     * @param miningFunction {@code REGRESSION} or {@code CLASSIFICATION}
     * @param activation hidden-layer activation name: {@code identity}, {@code logistic}, {@code relu} or {@code tanh}
     * @param coefs per-layer 2-D coefficient arrays
     * @param intercepts per-layer 1-D intercept arrays, same length as {@code coefs}
     * @param schema provides the label and the input features
     * @return the encoded neural network
     * @throws IllegalArgumentException on an unsupported mining function or activation, on empty or
     *     inconsistently-shaped arrays, or on a categorical label with fewer than two categories
     */
    public static NeuralNetwork encodeNeuralNetwork(MiningFunction miningFunction, String activation, List<? extends HasArray> coefs, List<? extends HasArray> intercepts, Schema schema) {
        NeuralNetwork.ActivationFunction activationFunction = parseActivationFunction(activation);

        ClassDictUtil.checkSize(coefs, intercepts);

        if (coefs.isEmpty()) {
            throw new IllegalArgumentException("Expected at least one layer, got none");
        }

        Label label = schema.getLabel();
        List features = schema.getFeatures();

        NeuralInputs neuralInputs = NeuralNetworkUtil.createNeuralInputs(features, DataType.DOUBLE);

        // The entities feeding the layer currently being built: first the neural inputs, then the
        // neurons of the previously built layer
        List entities = neuralInputs.getNeuralInputs();

        List<NeuralLayer> neuralLayers = new ArrayList<>();

        for (int layer = 0; layer < coefs.size(); layer++) {
            HasArray coef = coefs.get(layer);
            HasArray intercept = intercepts.get(layer);

            int[] shape = getMatrixShape(coef, layer);
            int rows = shape[0];
            int columns = shape[1];

            List<?> coefMatrix = coef.getArrayContent();
            List<?> interceptVector = intercept.getArrayContent();

            // Every neuron needs exactly one weight per upstream entity and exactly one bias
            if (rows != entities.size()) {
                throw new IllegalArgumentException("Layer " + layer + " expects " + rows + " inputs, but the previous layer provides " + entities.size());
            }
            if (interceptVector.size() != columns) {
                throw new IllegalArgumentException("Layer " + layer + " has " + columns + " neurons, but " + interceptVector.size() + " intercepts");
            }

            NeuralLayer neuralLayer = new NeuralLayer();

            for (int column = 0; column < columns; column++) {
                List weights = CMatrixUtil.getColumn(coefMatrix, rows, columns, column);
                Double bias = ValueUtil.asDouble((Number)interceptVector.get(column));

                Neuron neuron = NeuralNetworkUtil.createNeuron(entities, weights, bias)
                    .setId((layer + 1) + "/" + (column + 1));

                neuralLayer.addNeurons(neuron);
            }

            if (layer == coefs.size() - 1) {
                // The raw output layer is linear; the task-specific output transformation is applied
                // below, either in place or as additional layers
                neuralLayer.setActivationFunction(NeuralNetwork.ActivationFunction.IDENTITY);

                switch (miningFunction) {
                    case REGRESSION:
                        break;
                    case CLASSIFICATION: {
                        CategoricalLabel categoricalLabel = (CategoricalLabel)label;
                        int size = categoricalLabel.size();

                        if (size == 2) {
                            // Single output neuron -> logistic -> (1 - p, p)
                            neuralLayers.add(neuralLayer);
                            neuralLayer = encodeLogisticTransform(getOnlyNeuron(neuralLayer));

                            neuralLayers.add(neuralLayer);
                            neuralLayer = encodeLabelBinarizerTransform(getOnlyNeuron(neuralLayer));
                        } else if (size > 2) {
                            neuralLayer.setNormalizationMethod(NeuralNetwork.NormalizationMethod.SOFTMAX);
                        } else {
                            throw new IllegalArgumentException("Expected 2 or more target categories, got " + size);
                        }
                        break;
                    }
                    default:
                        throw new IllegalArgumentException("Unsupported mining function " + miningFunction);
                }
            }

            entities = neuralLayer.getNeurons();

            neuralLayers.add(neuralLayer);
        }

        NeuralOutputs neuralOutputs;

        switch (miningFunction) {
            case REGRESSION:
                neuralOutputs = NeuralNetworkUtil.createRegressionNeuralOutputs(entities, (ContinuousLabel)label);
                break;
            case CLASSIFICATION:
                neuralOutputs = NeuralNetworkUtil.createClassificationNeuralOutputs(entities, (CategoricalLabel)label);
                break;
            default:
                throw new IllegalArgumentException("Unsupported mining function " + miningFunction);
        }

        return new NeuralNetwork(miningFunction, activationFunction, ModelUtil.createMiningSchema(label), neuralInputs, neuralLayers)
            .setNeuralOutputs(neuralOutputs);
    }

    /**
     * Returns the shape of a coefficient array, verifying that it is two-dimensional.
     *
     * @param coef the coefficient array
     * @param layer zero-based layer index, used only in the error message
     * @return the {@code {rows, columns}} shape
     * @throws IllegalArgumentException if the array is not 2-D
     */
    private static int[] getMatrixShape(HasArray coef, int layer) {
        int[] shape = coef.getArrayShape();

        if (shape.length != 2) {
            throw new IllegalArgumentException("Expected a 2-dimensional coefficient matrix for layer " + layer + ", got " + shape.length + " dimension(s)");
        }

        return shape;
    }

    /**
     * Builds a one-neuron layer that applies the logistic function to {@code input}
     * (weight 1, bias 0).
     */
    private static NeuralLayer encodeLogisticTransform(Neuron input) {
        NeuralLayer neuralLayer = new NeuralLayer()
            .setActivationFunction(NeuralNetwork.ActivationFunction.LOGISTIC);

        Neuron neuron = new Neuron()
            .setId("logistic/1")
            .setBias(0d)
            .addConnections(new Connection(input.getId(), 1.0));

        neuralLayer.addNeurons(neuron);

        return neuralLayer;
    }

    /**
     * Builds a two-neuron identity layer that expands a single probability {@code p} into
     * {@code (1 - p, p)}, i.e. the "no event" and "event" outputs in that order.
     */
    private static NeuralLayer encodeLabelBinarizerTransform(Neuron input) {
        NeuralLayer neuralLayer = new NeuralLayer()
            .setActivationFunction(NeuralNetwork.ActivationFunction.IDENTITY);

        Neuron noEventNeuron = new Neuron()
            .setId("event/false")
            .setBias(1d)
            .addConnections(new Connection(input.getId(), -1.0));

        Neuron eventNeuron = new Neuron()
            .setId("event/true")
            .setBias(0d)
            .addConnections(new Connection(input.getId(), 1.0));

        neuralLayer.addNeurons(noEventNeuron, eventNeuron);

        return neuralLayer;
    }

    /**
     * Returns the single neuron of {@code neuralLayer}.
     *
     * @throws IllegalArgumentException if the layer has zero or more than one neuron
     *     (as raised by Guava's {@code Iterables.getOnlyElement})
     */
    private static Neuron getOnlyNeuron(NeuralLayer neuralLayer) {
        List<?> neurons = neuralLayer.getNeurons();

        return (Neuron)Iterables.getOnlyElement(neurons);
    }

    /**
     * Maps a hidden-layer activation name to the corresponding PMML activation function.
     *
     * @throws IllegalArgumentException if the name is not recognized
     */
    private static NeuralNetwork.ActivationFunction parseActivationFunction(String activation) {
        switch (activation) {
            case "identity":
                return NeuralNetwork.ActivationFunction.IDENTITY;
            case "logistic":
                return NeuralNetwork.ActivationFunction.LOGISTIC;
            case "relu":
                return NeuralNetwork.ActivationFunction.RECTIFIER;
            case "tanh":
                return NeuralNetwork.ActivationFunction.TANH;
            default:
                throw new IllegalArgumentException(activation);
        }
    }
}

