/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.manager;

import java.util.List;
import org.dmg.pmml.ActivationFunctionType;
import org.dmg.pmml.Connection;
import org.dmg.pmml.DataType;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Expression;
import org.dmg.pmml.MiningFunctionType;
import org.dmg.pmml.MiningSchema;
import org.dmg.pmml.Model;
import org.dmg.pmml.NeuralInput;
import org.dmg.pmml.NeuralInputs;
import org.dmg.pmml.NeuralLayer;
import org.dmg.pmml.NeuralNetwork;
import org.dmg.pmml.NeuralOutput;
import org.dmg.pmml.NeuralOutputs;
import org.dmg.pmml.Neuron;
import org.dmg.pmml.NormContinuous;
import org.dmg.pmml.OpType;
import org.dmg.pmml.PMML;
import org.jpmml.manager.ModelManager;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
public class NeuralNetworkManager
extends ModelManager<NeuralNetwork> {
    /** The managed model; {@code null} until one is located in the PMML or created via {@link #createModel}. */
    private NeuralNetwork neuralNetwork = null;
    /**
     * Next numeric value handed out by {@link #nextId()}. Neural inputs and neurons share this
     * single sequence, so a {@link Connection} can reference either kind of node by id.
     */
    private int idSequence = 0;

    /**
     * Creates a manager with no PMML document and no model.
     */
    public NeuralNetworkManager() {
    }

    /**
     * Creates a manager for the first {@link NeuralNetwork} found in the content of {@code pmml}.
     *
     * @param pmml the PMML document to manage
     */
    public NeuralNetworkManager(PMML pmml) {
        this(pmml, NeuralNetworkManager.find(pmml.getContent(), NeuralNetwork.class));
    }

    /**
     * Creates a manager for the given model.
     *
     * @param pmml the PMML document to manage
     * @param neuralNetwork the model to manage; may be {@code null}, in which case
     *        {@link #createModel} can be used later
     */
    public NeuralNetworkManager(PMML pmml, NeuralNetwork neuralNetwork) {
        super(pmml);
        this.neuralNetwork = neuralNetwork;
        if (neuralNetwork != null) {
            // Deliberately a private static helper instead of the public (overridable)
            // getNeuronCount(): overridable calls from a constructor would see a half-built object.
            this.idSequence = initialIdSequence(neuralNetwork);
        }
    }

    @Override
    public String getSummary() {
        return "Neural network";
    }

    /**
     * Returns the managed model.
     *
     * @return the managed neural network
     * @throws RuntimeException whatever {@code ensureNotNull} raises when no model is present
     *         (exact type defined in {@link ModelManager})
     */
    @Override
    public NeuralNetwork getModel() {
        NeuralNetworkManager.ensureNotNull(this.neuralNetwork);
        return this.neuralNetwork;
    }

    /**
     * Creates a new, empty neural network, registers it with the PMML document's model list
     * and makes it the managed model.
     *
     * @param miningFunction the mining function of the new model
     * @param activationFunction the default activation function of the new model
     * @return the newly created model
     */
    public NeuralNetwork createModel(MiningFunctionType miningFunction, ActivationFunctionType activationFunction) {
        NeuralNetworkManager.ensureNull(this.neuralNetwork);
        this.neuralNetwork = new NeuralNetwork(new MiningSchema(), new NeuralInputs(), miningFunction, activationFunction);
        getModels().add(this.neuralNetwork);
        return this.neuralNetwork;
    }

    /**
     * Returns the live list of neural inputs of the managed model.
     *
     * @return the model's (mutable) neural input list
     */
    public List<NeuralInput> getNeuralInputs() {
        return getModel().getNeuralInputs().getNeuralInputs();
    }

    /**
     * Appends a continuous, double-typed neural input whose value is given by {@code normContinuous}.
     *
     * @param normContinuous the normalization expression for the input
     * @return the new input, already added to the model, with a freshly allocated id
     */
    public NeuralInput addNeuralInput(NormContinuous normContinuous) {
        NeuralInput neuralInput = new NeuralInput(createContinuousField(normContinuous), nextId());
        getNeuralInputs().add(neuralInput);
        return neuralInput;
    }

    /**
     * Returns the live list of neural layers of the managed model.
     *
     * @return the model's (mutable) neural layer list
     */
    public List<NeuralLayer> getNeuralLayers() {
        return getModel().getNeuralLayers();
    }

    /**
     * Appends a new, empty layer to the managed model.
     *
     * @return the new layer
     */
    public NeuralLayer addNeuralLayer() {
        NeuralLayer neuralLayer = new NeuralLayer();
        getNeuralLayers().add(neuralLayer);
        return neuralLayer;
    }

    /**
     * Counts the nodes of the managed model: all neural inputs plus the neurons of every layer.
     *
     * @return the total number of inputs and neurons
     */
    public int getNeuronCount() {
        int count = getNeuralInputs().size();
        for (NeuralLayer neuralLayer : getNeuralLayers()) {
            count += neuralLayer.getNeurons().size();
        }
        return count;
    }

    /**
     * Appends a new neuron with a freshly allocated id to {@code neuralLayer}.
     *
     * @param neuralLayer the layer that receives the neuron
     * @param bias the neuron's bias; passed through to {@code Neuron.setBias} as-is
     * @return the new neuron
     */
    public Neuron addNeuron(NeuralLayer neuralLayer, Double bias) {
        Neuron neuron = new Neuron(nextId());
        neuron.setBias(bias);
        neuralLayer.getNeurons().add(neuron);
        return neuron;
    }

    /**
     * Adds a weighted connection from a neural input to a neuron.
     *
     * @param from the source input
     * @param to the target neuron, which stores the connection
     * @param weight the connection weight
     */
    public static void addConnection(NeuralInput from, Neuron to, double weight) {
        connect(from.getId(), to, weight);
    }

    /**
     * Adds a weighted connection from one neuron to another.
     *
     * @param from the source neuron
     * @param to the target neuron, which stores the connection
     * @param weight the connection weight
     */
    public static void addConnection(Neuron from, Neuron to, double weight) {
        connect(from.getId(), to, weight);
    }

    /**
     * Returns the live list of neural outputs, first attaching an empty {@link NeuralOutputs}
     * element to the model if it has none.
     *
     * @return the model's (mutable) neural output list
     */
    public List<NeuralOutput> getOrCreateNeuralOutputs() {
        NeuralNetwork model = getModel();
        NeuralOutputs neuralOutputs = model.getNeuralOutputs();
        if (neuralOutputs == null) {
            neuralOutputs = new NeuralOutputs();
            model.setNeuralOutputs(neuralOutputs);
        }
        return neuralOutputs.getNeuralOutputs();
    }

    /**
     * Appends a continuous, double-typed neural output bound to {@code neuron}.
     *
     * @param neuron the neuron whose value feeds the output
     * @param normContinuous the (de)normalization expression for the output
     * @return the new output, already added to the model
     */
    public NeuralOutput addNeuralOutput(Neuron neuron, NormContinuous normContinuous) {
        NeuralOutput output = new NeuralOutput(createContinuousField(normContinuous), neuron.getId());
        getOrCreateNeuralOutputs().add(output);
        return output;
    }

    /** Allocates the next node id from the shared input/neuron id sequence. */
    private String nextId() {
        return String.valueOf(this.idSequence++);
    }

    /** Builds the continuous, double-typed derived field wrapping {@code normContinuous}. */
    private static DerivedField createContinuousField(NormContinuous normContinuous) {
        DerivedField derivedField = new DerivedField(OpType.CONTINUOUS, DataType.DOUBLE);
        derivedField.setExpression(normContinuous);
        return derivedField;
    }

    /** Records a connection from node {@code fromId} on the target neuron {@code to}. */
    private static void connect(String fromId, Neuron to, double weight) {
        to.getConnections().add(new Connection(fromId, weight));
    }

    /**
     * Computes the first id that cannot collide with an existing input or neuron id.
     * A pure node count is only safe when existing ids are exactly 0..n-1; models using
     * 1-based or gapped numeric ids would otherwise receive duplicate ids. For 0..n-1
     * numbering the result equals the node count, as before.
     */
    private static int initialIdSequence(NeuralNetwork network) {
        int count = 0;
        int maxNumericId = -1;
        for (NeuralInput neuralInput : network.getNeuralInputs().getNeuralInputs()) {
            count++;
            maxNumericId = Math.max(maxNumericId, parseNumericId(neuralInput.getId()));
        }
        for (NeuralLayer neuralLayer : network.getNeuralLayers()) {
            for (Neuron neuron : neuralLayer.getNeurons()) {
                count++;
                maxNumericId = Math.max(maxNumericId, parseNumericId(neuron.getId()));
            }
        }
        return Math.max(count, maxNumericId + 1);
    }

    /**
     * Parses an id as an int, returning -1 for null or non-numeric ids (those can never
     * equal a generated decimal id, so they need not influence the sequence).
     */
    private static int parseNumericId(String id) {
        if (id == null) {
            return -1;
        }
        try {
            return Integer.parseInt(id);
        } catch (NumberFormatException ignored) {
            // Non-numeric (or out-of-int-range) ids cannot clash with small generated ids.
            return -1;
        }
    }
}

