/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.manager;

import com.google.common.base.Preconditions;
import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;
import java.util.List;
import org.dmg.pmml.ActivationFunctionType;
import org.dmg.pmml.Connection;
import org.dmg.pmml.DataType;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Entity;
import org.dmg.pmml.MiningFunctionType;
import org.dmg.pmml.MiningSchema;
import org.dmg.pmml.NeuralInput;
import org.dmg.pmml.NeuralInputs;
import org.dmg.pmml.NeuralLayer;
import org.dmg.pmml.NeuralNetwork;
import org.dmg.pmml.NeuralOutput;
import org.dmg.pmml.NeuralOutputs;
import org.dmg.pmml.Neuron;
import org.dmg.pmml.NormContinuous;
import org.dmg.pmml.OpType;
import org.dmg.pmml.PMML;
import org.jpmml.manager.HasEntityRegistry;
import org.jpmml.manager.ModelManager;

/**
 * Manager for a PMML {@link NeuralNetwork} model.
 *
 * <p>Wraps a single {@link NeuralNetwork} instance and offers convenience
 * methods for building it up: adding neural inputs, layers, neurons,
 * connections and neural outputs. It also exposes the inputs and neurons of
 * the network through {@link #getEntityRegistry()}.
 *
 * <p>The wrapped model is either supplied through a constructor or created
 * later with {@link #createModel(MiningFunctionType, ActivationFunctionType)}.
 */
public class NeuralNetworkManager
extends ModelManager<NeuralNetwork>
implements HasEntityRegistry<Entity> {
    /** The managed model; {@code null} until supplied or created. */
    private NeuralNetwork neuralNetwork;

    /**
     * Creates a manager without a model. Call
     * {@link #createModel(MiningFunctionType, ActivationFunctionType)} before
     * using any accessor that requires the model.
     */
    public NeuralNetworkManager() {
    }

    /**
     * Creates a manager for the {@link NeuralNetwork} looked up in the content
     * of the given PMML document.
     *
     * @param pmml the PMML document to search
     */
    public NeuralNetworkManager(PMML pmml) {
        this(pmml, NeuralNetworkManager.find(pmml.getContent(), NeuralNetwork.class));
    }

    /**
     * Creates a manager for the given model.
     *
     * @param pmml the PMML document handed to the superclass
     * @param neuralNetwork the model to manage
     */
    public NeuralNetworkManager(PMML pmml, NeuralNetwork neuralNetwork) {
        super(pmml);
        this.neuralNetwork = neuralNetwork;
    }

    /**
     * Returns a short, fixed description of the model type.
     *
     * @return the string {@code "Neural network"}
     */
    @Override
    public String getSummary() {
        return "Neural network";
    }

    /**
     * Returns the managed model.
     *
     * @return the managed {@link NeuralNetwork}, never {@code null}
     * @throws IllegalStateException if no model has been supplied or created yet
     */
    @Override
    public NeuralNetwork getModel() {
        Preconditions.checkState(this.neuralNetwork != null,
            "No neural network model is available; supply one to the constructor or call createModel first");
        return this.neuralNetwork;
    }

    /**
     * Creates a new model with an empty {@link MiningSchema} and empty
     * {@link NeuralInputs}, appends it to the list returned by
     * {@code getModels()} and makes it the managed model.
     *
     * @param miningFunction the mining function of the new model
     * @param activationFunction the activation function of the new model
     * @return the newly created model
     * @throws IllegalStateException if this manager already has a model
     */
    public NeuralNetwork createModel(MiningFunctionType miningFunction, ActivationFunctionType activationFunction) {
        Preconditions.checkState(this.neuralNetwork == null,
            "This manager already has a neural network model");
        this.neuralNetwork = new NeuralNetwork(new MiningSchema(), new NeuralInputs(), miningFunction, activationFunction);
        this.getModels().add(this.neuralNetwork);
        return this.neuralNetwork;
    }

    /**
     * Returns the neural inputs of the managed model.
     *
     * @return the list obtained from the model's {@link NeuralInputs} element
     * @throws IllegalStateException if no model is available
     */
    public List<NeuralInput> getNeuralInputs() {
        return this.getModel().getNeuralInputs().getNeuralInputs();
    }

    /**
     * Creates a neural input backed by a continuous, double-typed derived
     * field with the given expression, and appends it to the model's inputs.
     *
     * @param id the identifier of the new neural input
     * @param normContinuous the expression assigned to the derived field
     * @return the newly added neural input
     * @throws IllegalStateException if no model is available
     */
    public NeuralInput addNeuralInput(String id, NormContinuous normContinuous) {
        NeuralInput neuralInput = new NeuralInput(NeuralNetworkManager.newContinuousDoubleField(normContinuous), id);
        this.getNeuralInputs().add(neuralInput);
        return neuralInput;
    }

    /**
     * Returns the neural layers of the managed model.
     *
     * @return the model's list of layers
     * @throws IllegalStateException if no model is available
     */
    public List<NeuralLayer> getNeuralLayers() {
        return this.getModel().getNeuralLayers();
    }

    /**
     * Appends a new, empty layer to the managed model.
     *
     * @return the newly added layer
     * @throws IllegalStateException if no model is available
     */
    public NeuralLayer addNeuralLayer() {
        NeuralLayer neuralLayer = new NeuralLayer();
        this.getNeuralLayers().add(neuralLayer);
        return neuralLayer;
    }

    /**
     * Builds a registry containing every neural input and every neuron of
     * every layer of the managed model. A new map is built on each call.
     *
     * <p>Entries are inserted through the inherited {@code putEntity} helper;
     * the key is presumably the entity identifier (confirm against
     * {@code ModelManager}).
     *
     * @return a freshly built registry of inputs and neurons
     * @throws IllegalStateException if no model is available
     */
    @Override
    public BiMap<String, Entity> getEntityRegistry() {
        BiMap<String, Entity> registry = HashBiMap.create();
        for (NeuralInput neuralInput : this.getNeuralInputs()) {
            NeuralNetworkManager.putEntity(neuralInput, registry);
        }
        for (NeuralLayer neuralLayer : this.getNeuralLayers()) {
            for (Neuron neuron : neuralLayer.getNeurons()) {
                NeuralNetworkManager.putEntity(neuron, registry);
            }
        }
        return registry;
    }

    /**
     * Creates a neuron with the given identifier and bias and appends it to
     * the given layer.
     *
     * @param neuralLayer the layer that receives the neuron
     * @param id the identifier of the new neuron
     * @param bias the bias assigned to the new neuron
     * @return the newly added neuron
     */
    public static Neuron addNeuron(NeuralLayer neuralLayer, String id, Double bias) {
        Neuron neuron = new Neuron(id);
        neuron.setBias(bias);
        neuralLayer.getNeurons().add(neuron);
        return neuron;
    }

    /**
     * Adds a weighted connection from a neural input to a neuron.
     *
     * @param from the neural input whose identifier is recorded as the source
     * @param to the neuron that receives the connection
     * @param weight the connection weight
     */
    public static void addConnection(NeuralInput from, Neuron to, double weight) {
        NeuralNetworkManager.appendConnection(from.getId(), to, weight);
    }

    /**
     * Adds a weighted connection from one neuron to another.
     *
     * @param from the neuron whose identifier is recorded as the source
     * @param to the neuron that receives the connection
     * @param weight the connection weight
     */
    public static void addConnection(Neuron from, Neuron to, double weight) {
        NeuralNetworkManager.appendConnection(from.getId(), to, weight);
    }

    /**
     * Returns the neural outputs of the managed model, first creating and
     * attaching an empty {@link NeuralOutputs} element if the model has none.
     *
     * @return the list obtained from the model's {@link NeuralOutputs} element
     * @throws IllegalStateException if no model is available
     */
    public List<NeuralOutput> getOrCreateNeuralOutputs() {
        NeuralNetwork model = this.getModel();
        NeuralOutputs neuralOutputs = model.getNeuralOutputs();
        if (neuralOutputs == null) {
            neuralOutputs = new NeuralOutputs();
            model.setNeuralOutputs(neuralOutputs);
        }
        return neuralOutputs.getNeuralOutputs();
    }

    /**
     * Creates a neural output that refers to the given neuron by identifier
     * and is backed by a continuous, double-typed derived field with the given
     * expression, then appends it to the model's outputs.
     *
     * @param neuron the neuron whose identifier the output refers to
     * @param normCountinuous the expression assigned to the derived field
     * @return the newly added neural output
     * @throws IllegalStateException if no model is available
     */
    public NeuralOutput addNeuralOutput(Neuron neuron, NormContinuous normCountinuous) {
        NeuralOutput output = new NeuralOutput(NeuralNetworkManager.newContinuousDoubleField(normCountinuous), neuron.getId());
        this.getOrCreateNeuralOutputs().add(output);
        return output;
    }

    /** Builds a continuous, double-typed derived field holding the given expression. */
    private static DerivedField newContinuousDoubleField(NormContinuous expression) {
        DerivedField derivedField = new DerivedField(OpType.CONTINUOUS, DataType.DOUBLE);
        derivedField.setExpression(expression);
        return derivedField;
    }

    /** Appends a connection with the given source identifier and weight to the target neuron. */
    private static void appendConnection(String fromId, Neuron to, double weight) {
        to.getConnections().add(new Connection(fromId, weight));
    }
}

