/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.sparkml.model;

import java.util.ArrayList;
import java.util.List;
import org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel;
import org.apache.spark.ml.linalg.Vector;
import org.dmg.pmml.ActivationFunctionType;
import org.dmg.pmml.Connection;
import org.dmg.pmml.DataType;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Entity;
import org.dmg.pmml.Expression;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.FieldRef;
import org.dmg.pmml.MiningFunctionType;
import org.dmg.pmml.NeuralInput;
import org.dmg.pmml.NeuralInputs;
import org.dmg.pmml.NeuralLayer;
import org.dmg.pmml.NeuralNetwork;
import org.dmg.pmml.NeuralOutput;
import org.dmg.pmml.NeuralOutputs;
import org.dmg.pmml.Neuron;
import org.dmg.pmml.NnNormalizationMethodType;
import org.dmg.pmml.NormDiscrete;
import org.dmg.pmml.OpType;
import org.jpmml.converter.BinaryFeature;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.sparkml.ClassificationModelConverter;

/**
 * Converts a Spark ML {@link MultilayerPerceptronClassificationModel} into a PMML
 * {@link NeuralNetwork} classification model.
 */
public class MultilayerPerceptronClassificationModelConverter
extends ClassificationModelConverter<MultilayerPerceptronClassificationModel> {
    public MultilayerPerceptronClassificationModelConverter(MultilayerPerceptronClassificationModel model) {
        super(model);
    }

    /**
     * Encodes the wrapped multilayer perceptron as a PMML neural network.
     *
     * <p>Hidden layers set no activation of their own, so they use the network-wide
     * {@code LOGISTIC} activation. The final layer uses an {@code IDENTITY} activation
     * followed by {@code SOFTMAX} normalization. There is one output neuron per target category.
     *
     * <p>For layer {@code i}, the weight of the connection from previous-layer entity {@code row}
     * to neuron {@code column} is read at {@code offset + row * columns + column}. The
     * {@code columns} bias values, one per neuron, follow immediately after. This presumably
     * mirrors Spark's flattened per-layer layout (a column-major {@code numOut x numIn} matrix,
     * then a bias vector); confirm against the Spark version in use.
     *
     * @param schema the feature/target schema; it must hold exactly {@code layers()[0]} features
     *     and exactly {@code layers()[layers().length - 1]} target categories
     * @return the PMML neural network, including a probability output
     * @throws IllegalArgumentException if the topology has fewer than two layers, if the feature
     *     count, target category count or weight vector size disagree with the topology, or if a
     *     feature is neither a {@link ContinuousFeature} nor a {@link BinaryFeature}
     */
    public NeuralNetwork encodeModel(Schema schema) {
        MultilayerPerceptronClassificationModel model = (MultilayerPerceptronClassificationModel)this.getTransformer();

        int[] layers = model.layers();
        Vector weights = model.weights();

        // An MLP needs at least an input and an output layer; anything less cannot be encoded.
        if (layers.length < 2) {
            throw new IllegalArgumentException("Expected at least 2 layers (input and output), got " + layers.length);
        }

        List<? extends Feature> features = schema.getFeatures();
        if (features.size() != layers[0]) {
            throw new IllegalArgumentException("Expected " + layers[0] + " features, got " + features.size());
        }

        FieldName targetField = schema.getTargetField();
        List<String> targetCategories = schema.getTargetCategories();
        int outputSize = layers[layers.length - 1];
        if (targetCategories.size() != outputSize) {
            throw new IllegalArgumentException("Expected " + outputSize + " target categories, got " + targetCategories.size());
        }

        // Validate the weight vector up front so a short vector fails with a clear message
        // instead of an out-of-range access inside the layer loop.
        long expectedWeights = countWeights(layers);
        if (weights.size() != expectedWeights) {
            throw new IllegalArgumentException("Expected " + expectedWeights + " weights, got " + weights.size());
        }

        NeuralInputs neuralInputs = encodeNeuralInputs(features);

        // Entities feeding the current layer: the neural inputs first, then each layer's neurons.
        List<? extends Entity> entities = neuralInputs.getNeuralInputs();
        List<NeuralLayer> neuralLayers = new ArrayList<>();

        int weightPos = 0;
        for (int i = 1; i < layers.length; ++i) {
            int rows = entities.size();
            int columns = layers[i];

            List<Neuron> neurons = new ArrayList<>(columns);
            for (int column = 0; column < columns; ++column) {
                Neuron neuron = new Neuron().setId(i + "/" + (column + 1));
                for (int row = 0; row < rows; ++row) {
                    Entity entity = entities.get(row);
                    Connection connection = new Connection()
                        .setFrom(entity.getId())
                        .setWeight(weights.apply(weightPos + row * columns + column));
                    neuron.addConnections(connection);
                }
                neurons.add(neuron);
            }
            weightPos += rows * columns;

            // Bias values follow the connection weights of the same layer, one per neuron.
            for (Neuron neuron : neurons) {
                neuron.setBias(weights.apply(weightPos));
                weightPos++;
            }

            NeuralLayer neuralLayer = new NeuralLayer(neurons);
            if (i == layers.length - 1) {
                neuralLayer.setActivationFunction(ActivationFunctionType.IDENTITY).setNormalizationMethod(NnNormalizationMethodType.SOFTMAX);
            }
            neuralLayers.add(neuralLayer);

            entities = neurons;
        }

        NeuralOutputs neuralOutputs = encodeNeuralOutputs(targetField, targetCategories, entities);

        NeuralNetwork neuralNetwork = new NeuralNetwork(MiningFunctionType.CLASSIFICATION, ActivationFunctionType.LOGISTIC, ModelUtil.createMiningSchema(schema), neuralInputs, neuralLayers)
            .setNeuralOutputs(neuralOutputs)
            .setOutput(ModelUtil.createProbabilityOutput(schema));

        return neuralNetwork;
    }

    /**
     * Builds one {@link NeuralInput} per feature, with ids {@code "0/1"} through {@code "0/n"}.
     */
    private static NeuralInputs encodeNeuralInputs(List<? extends Feature> features) {
        NeuralInputs neuralInputs = new NeuralInputs();
        for (int column = 0; column < features.size(); ++column) {
            DerivedField derivedField = encodeInputField(features.get(column));
            NeuralInput neuralInput = new NeuralInput()
                .setId("0/" + (column + 1))
                .setDerivedField(derivedField);
            neuralInputs.addNeuralInputs(neuralInput);
        }
        return neuralInputs;
    }

    /**
     * Maps a feature to a continuous double field.
     *
     * <p>A continuous feature becomes a field reference. A binary feature becomes a
     * {@code NormDiscrete} indicator on its value.
     *
     * @throws IllegalArgumentException for any other feature type
     */
    private static DerivedField encodeInputField(Feature feature) {
        DerivedField derivedField = new DerivedField(OpType.CONTINUOUS, DataType.DOUBLE);
        if (feature instanceof ContinuousFeature) {
            ContinuousFeature continuousFeature = (ContinuousFeature)feature;
            derivedField.setExpression(new FieldRef(continuousFeature.getName()));
        } else if (feature instanceof BinaryFeature) {
            BinaryFeature binaryFeature = (BinaryFeature)feature;
            derivedField.setExpression(new NormDiscrete(binaryFeature.getName(), binaryFeature.getValue()));
        } else {
            String type = (feature != null) ? feature.getClass().getName() : "null";
            throw new IllegalArgumentException("Unsupported feature type: " + type);
        }
        return derivedField;
    }

    /**
     * Returns the number of weights the topology requires.
     *
     * <p>Each layer {@code i >= 1} needs {@code layers[i - 1] * layers[i]} connection weights
     * plus {@code layers[i]} biases. The sum is computed in {@code long} to avoid overflow.
     */
    private static long countWeights(int[] layers) {
        long count = 0L;
        for (int i = 1; i < layers.length; ++i) {
            count += ((long)layers[i - 1] + 1L) * layers[i];
        }
        return count;
    }

    /**
     * Links the output entity at position {@code k} to target category {@code k}.
     *
     * <p>Each link is a {@code NormDiscrete} on the target field.
     */
    private static NeuralOutputs encodeNeuralOutputs(FieldName targetField, List<String> targetCategories, List<? extends Entity> outputEntities) {
        NeuralOutputs neuralOutputs = new NeuralOutputs();
        for (int column = 0; column < targetCategories.size(); ++column) {
            String targetCategory = targetCategories.get(column);
            Entity entity = outputEntities.get(column);
            DerivedField derivedField = new DerivedField(OpType.CATEGORICAL, DataType.STRING)
                .setExpression(new NormDiscrete(targetField, targetCategory));
            NeuralOutput neuralOutput = new NeuralOutput()
                .setOutputNeuron(entity.getId())
                .setDerivedField(derivedField);
            neuralOutputs.addNeuralOutputs(neuralOutput);
        }
        return neuralOutputs;
    }
}

