/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.sparkml.model;

import java.util.ArrayList;
import java.util.List;
import org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.param.shared.HasProbabilityCol;
import org.dmg.pmml.DataType;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.neural_network.NeuralInputs;
import org.dmg.pmml.neural_network.NeuralLayer;
import org.dmg.pmml.neural_network.NeuralNetwork;
import org.dmg.pmml.neural_network.Neuron;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.neural_network.NeuralNetworkUtil;
import org.jpmml.sparkml.ClassificationModelConverter;
import org.jpmml.sparkml.SchemaUtil;
import org.jpmml.sparkml.SparkMLEncoder;

/**
 * Converts a Spark ML {@link MultilayerPerceptronClassificationModel} into a PMML
 * {@link NeuralNetwork} classification model.
 *
 * <p>Hidden layers use the network-wide logistic activation; the output layer overrides this
 * with identity activation followed by softmax normalization.
 */
public class MultilayerPerceptronClassificationModelConverter
extends ClassificationModelConverter<MultilayerPerceptronClassificationModel> {

    /**
     * @param model the fitted Spark multilayer perceptron classification model
     */
    public MultilayerPerceptronClassificationModelConverter(MultilayerPerceptronClassificationModel model) {
        super(model);
    }

    /**
     * Registers the output fields of the converted model.
     *
     * <p>When the Spark model does not expose a probability column ({@link HasProbabilityCol}),
     * one {@code double} probability field per target category is appended to the fields
     * registered by the superclass.
     *
     * @param label the target label; cast to {@link CategoricalLabel} when probability fields are added
     * @param encoder the Spark ML encoder
     * @return the output fields
     */
    @Override
    public List<OutputField> registerOutputFields(Label label, SparkMLEncoder encoder) {
        MultilayerPerceptronClassificationModel model = getTransformer();
        List<OutputField> result = super.registerOutputFields(label, encoder);
        if (!(model instanceof HasProbabilityCol)) {
            CategoricalLabel categoricalLabel = (CategoricalLabel)label;
            // Copy before appending: the list returned by the superclass may not be modifiable.
            result = new ArrayList<>(result);
            result.addAll(ModelUtil.createProbabilityFields(DataType.DOUBLE, categoricalLabel.getValues()));
        }
        return result;
    }

    /**
     * Encodes the Spark multilayer perceptron as a PMML neural network.
     *
     * <p>Spark stores each layer's parameters contiguously in {@code weights}: first an
     * {@code inputs x outputs} weight block, then one bias per output neuron. Within a layer's
     * block, the weight connecting input {@code row} to output {@code column} is located at
     * offset {@code row * outputs + column}.
     *
     * @param schema the schema; its label must be a {@link CategoricalLabel}
     * @return the PMML neural network
     * @throws IllegalArgumentException if the weight vector size does not match the layer topology
     */
    @Override
    public NeuralNetwork encodeModel(Schema schema) {
        MultilayerPerceptronClassificationModel model = getTransformer();
        int[] layers = model.layers();
        Vector weights = model.weights();

        CategoricalLabel categoricalLabel = (CategoricalLabel)schema.getLabel();
        SchemaUtil.checkSize(layers[layers.length - 1], categoricalLabel);
        SchemaUtil.checkSize(layers[0], schema.getFeatures());

        NeuralInputs neuralInputs = NeuralNetworkUtil.createNeuralInputs(schema.getFeatures(), DataType.DOUBLE);

        // NOTE(review): kept as a raw List because it first holds the NeuralInput elements and later
        // the previous layer's Neuron elements; their common supertype differs across JPMML-Model versions.
        List entities = neuralInputs.getNeuralInputs();

        // Validate up front so a malformed vector fails with a clear message instead of an
        // index-out-of-bounds error midway through the layer construction.
        int expectedWeights = countWeights(entities.size(), layers);
        if (weights.size() != expectedWeights) {
            throw new IllegalArgumentException("Expected " + expectedWeights + " weights, got " + weights.size());
        }

        List<NeuralLayer> neuralLayers = new ArrayList<>();
        int weightPos = 0;
        for (int layer = 1; layer < layers.length; ++layer) {
            int rows = entities.size();
            int columns = layers[layer];
            // This layer's biases follow its rows * columns weight block.
            int biasPos = weightPos + rows * columns;

            NeuralLayer neuralLayer = new NeuralLayer();
            for (int column = 0; column < columns; ++column) {
                List<Double> weightVector = new ArrayList<>(rows);
                for (int row = 0; row < rows; ++row) {
                    weightVector.add(weights.apply(weightPos + row * columns + column));
                }
                Double bias = weights.apply(biasPos + column);
                Neuron neuron = NeuralNetworkUtil.createNeuron(entities, weightVector, bias)
                    .setId(layer + "/" + (column + 1));
                neuralLayer.addNeurons(neuron);
            }
            weightPos = biasPos + columns;

            if (layer == layers.length - 1) {
                // Output layer: identity activation, then softmax across the class neurons.
                neuralLayer.setActivationFunction(NeuralNetwork.ActivationFunction.IDENTITY)
                    .setNormalizationMethod(NeuralNetwork.NormalizationMethod.SOFTMAX);
            }
            neuralLayers.add(neuralLayer);
            entities = neuralLayer.getNeurons();
        }

        return new NeuralNetwork(MiningFunction.CLASSIFICATION, NeuralNetwork.ActivationFunction.LOGISTIC,
                ModelUtil.createMiningSchema(categoricalLabel), neuralInputs, neuralLayers)
            .setNeuralOutputs(NeuralNetworkUtil.createClassificationNeuralOutputs(entities, categoricalLabel));
    }

    /**
     * Returns the total number of weights plus biases for a fully connected network whose first
     * layer receives {@code inputs} values; {@code layers[0]} is ignored in favor of {@code inputs}.
     */
    private static int countWeights(int inputs, int[] layers) {
        int count = 0;
        int fanIn = inputs;
        for (int layer = 1; layer < layers.length; ++layer) {
            // One weight per incoming connection plus one bias, for every neuron of this layer.
            count += (fanIn + 1) * layers[layer];
            fanIn = layers[layer];
        }
        return count;
    }
}

