package org.deeplearning4j.nn.multilayer;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.commons.lang3.ArrayUtils;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.Classifier;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.OutputLayer;
import org.deeplearning4j.nn.layers.factory.LayerFactories;
import org.deeplearning4j.nn.params.DefaultParamInitializer;
import org.deeplearning4j.optimize.GradientAdjustment;
import org.deeplearning4j.optimize.api.ConvexOptimizer;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.util.MultiLayerUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.util.FeatureUtil;
import org.nd4j.linalg.util.LinAlgExceptions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/nn/multilayer/MultiLayerNetwork.class */
public class MultiLayerNetwork implements Serializable, Classifier {
    private static final Logger log = LoggerFactory.getLogger(MultiLayerNetwork.class);
    private static final long serialVersionUID = -5029161847383716484L;
    // One Layer instance per configured layer; allocated in intializeConfigurations()/init().
    protected Layer[] layers;
    // Current network input; assigned by feedForward(INDArray), initializeLayers(), pretrain(), and
    // (presumably) setInput(), which is defined elsewhere. init() may set it to a dummy ones vector.
    protected INDArray input;
    // Labels for the most recent supervised call (initialize, doBackWard, finetune).
    protected INDArray labels;
    // True once init() has built the layers; initializeLayers() skips init() afterwards.
    protected boolean initCalled;
    // Listeners passed to intermediate layers on creation in init() and to every layer in setListeners().
    private List<IterationListener> listeners;
    // Set to the configuration of layer 0 by the constructor; read for L2 and unit-norm gradient settings.
    protected NeuralNetConfiguration defaultConfiguration;
    // Per-layer configurations plus network-wide settings (hidden sizes, pretrain/backward flags, damping).
    protected MultiLayerConfiguration layerWiseConfigurations;
    // Parameter mask applied to the L2 terms in backPropGradient2(); built by initMask() (defined elsewhere).
    protected INDArray mask;

    /**
     * Creates a network from a layer-wise configuration without building any layers.
     * Layers are created later by {@link #init()} or {@link #initializeLayers(INDArray)}.
     *
     * @param multiLayerConfiguration the per-layer configuration; the configuration of layer 0
     *                                becomes the network's default configuration
     */
    public MultiLayerNetwork(MultiLayerConfiguration multiLayerConfiguration) {
        this.initCalled = false;
        this.listeners = new ArrayList<>();
        this.layerWiseConfigurations = multiLayerConfiguration;
        this.defaultConfiguration = multiLayerConfiguration.getConf(0);
    }

    /**
     * Recreates a network from a JSON configuration and a flattened parameter vector.
     *
     * @param str      the configuration serialized as JSON ({@link MultiLayerConfiguration#fromJson})
     * @param iNDArray all layer parameters flattened into one vector
     */
    public MultiLayerNetwork(String str, INDArray iNDArray) {
        this(MultiLayerConfiguration.fromJson(str));
        init();
        setParameters(iNDArray);
    }

    /**
     * Creates a network from a configuration, builds its layers and loads the given parameters.
     *
     * @param multiLayerConfiguration the per-layer configuration
     * @param iNDArray                all layer parameters flattened into one vector
     */
    public MultiLayerNetwork(MultiLayerConfiguration multiLayerConfiguration, INDArray iNDArray) {
        this(multiLayerConfiguration);
        init();
        setParameters(iNDArray);
    }

    /**
     * Fills in defaults for anything not configured yet.
     * <p>
     * Creates an empty layer-wise configuration, the layer array and the default configuration
     * when they are missing. When the layer-wise configuration has no per-layer confs, it appends
     * {@code hiddenLayerSizes.length + 1} clones of the default configuration.
     */
    protected void intializeConfigurations() {
        if (layerWiseConfigurations == null) {
            layerWiseConfigurations = new MultiLayerConfiguration.Builder().build();
        }
        if (layers == null) {
            layers = new Layer[getnLayers()];
        }
        if (defaultConfiguration == null) {
            defaultConfiguration = new NeuralNetConfiguration.Builder().build();
        }
        // layerWiseConfigurations is guaranteed non-null here, so only emptiness needs checking.
        if (layerWiseConfigurations.getConfs().isEmpty()) {
            for (int added = 0; added <= layerWiseConfigurations.getHiddenLayerSizes().length; added++) {
                layerWiseConfigurations.getConfs().add(defaultConfiguration.m17clone());
            }
        }
    }

    /**
     * Greedy layer-wise pretraining over every batch of the iterator, one full pass per layer.
     * <p>
     * For layer 0, each batch's features become the network input (initializing the layers on first
     * use) and are fed to the layer directly. For deeper layers, each batch is propagated through
     * all preceding layers before fitting. The iterator is reset after every layer's pass.
     * Does nothing unless the configuration enables pretraining.
     *
     * @param dataSetIterator source of training batches; must support {@code reset()}
     */
    public void pretrain(DataSetIterator dataSetIterator) {
        if (!this.layerWiseConfigurations.isPretrain()) {
            return;
        }
        for (int layerIndex = 0; layerIndex < getnLayers(); layerIndex++) {
            while (dataSetIterator.hasNext()) {
                DataSet batch = (DataSet) dataSetIterator.next();
                if (layerIndex == 0) {
                    this.input = batch.getFeatureMatrix();
                    boolean needsInit = getInput() == null || getLayers() == null;
                    setInput(this.input);
                    if (needsInit) {
                        initializeLayers(this.input);
                    }
                    getLayers()[layerIndex].fit(batch.getFeatureMatrix());
                    log.info("Training on layer " + (layerIndex + 1) + " with " + this.input.slices() + " examples");
                } else {
                    INDArray features = batch.getFeatureMatrix();
                    for (int below = 0; below < layerIndex; below++) {
                        features = activationFromPrevLayer(below, features);
                    }
                    log.info("Training on layer " + (layerIndex + 1) + " with " + features.slices() + " examples");
                    getLayers()[layerIndex].fit(features);
                }
            }
            dataSetIterator.reset();
        }
    }

    /**
     * Greedy layer-wise pretraining on a single input matrix.
     * <p>
     * Every layer except the last is fitted in turn. Layer {@code k > 0} is trained on the
     * activation of layer {@code k - 1}. Does nothing unless the configuration enables pretraining.
     *
     * @param iNDArray the training input for the first layer
     */
    public void pretrain(INDArray iNDArray) {
        if (!this.layerWiseConfigurations.isPretrain()) {
            return;
        }
        INDArray layerInput = null;
        for (int layer = 0; layer < getnLayers() - 1; layer++) {
            if (layer == 0) {
                layerInput = iNDArray;
            } else {
                layerInput = activationFromPrevLayer(layer - 1, layerInput);
            }
            log.info("Training on layer " + (layer + 1) + " with " + layerInput.slices() + " examples");
            getLayers()[layer].fit(layerInput);
        }
    }

    /** Returns the number of examples in the current input, i.e. {@code input.slices()}. */
    @Override // org.deeplearning4j.nn.api.Model
    public int batchSize() {
        return input.slices();
    }

    /** Not supported by this class; per-layer configurations come from {@link #getLayerWiseConfigurations()}. */
    @Override // org.deeplearning4j.nn.api.Model
    public NeuralNetConfiguration conf() {
        throw unsupportedModelOperation();
    }

    /** Not supported by this class; use {@link #setLayerWiseConfigurations(MultiLayerConfiguration)}. */
    @Override // org.deeplearning4j.nn.api.Model
    public void setConf(NeuralNetConfiguration neuralNetConfiguration) {
        throw unsupportedModelOperation();
    }

    /** Returns the current network input (may be {@code null} before any input was set). */
    @Override // org.deeplearning4j.nn.api.Model
    public INDArray input() {
        return input;
    }

    /** No-op: this implementation performs no input validation. */
    @Override // org.deeplearning4j.nn.api.Model
    public void validateInput() {
    }

    /** Not supported by this class; each layer owns its own optimizer. */
    @Override // org.deeplearning4j.nn.api.Model
    public ConvexOptimizer getOptimizer() {
        throw unsupportedModelOperation();
    }

    /** Not supported by this class. */
    @Override // org.deeplearning4j.nn.api.Model
    public INDArray getParam(String str) {
        throw unsupportedModelOperation();
    }

    /** Not supported by this class; parameters are initialized by the individual layers. */
    @Override // org.deeplearning4j.nn.api.Model
    public void initParams() {
        throw unsupportedModelOperation();
    }

    /** Not supported by this class. */
    @Override // org.deeplearning4j.nn.api.Model
    public Map<String, INDArray> paramTable() {
        throw unsupportedModelOperation();
    }

    /** Not supported by this class. */
    @Override // org.deeplearning4j.nn.api.Model
    public void setParamTable(Map<String, INDArray> map) {
        throw unsupportedModelOperation();
    }

    /** Not supported by this class. */
    @Override // org.deeplearning4j.nn.api.Model
    public void setParam(String str, INDArray iNDArray) {
        throw unsupportedModelOperation();
    }

    /** Same as {@link #output(INDArray)}: runs a forward pass and returns the final activation. */
    @Override // org.deeplearning4j.nn.api.Model
    public INDArray transform(INDArray examples) {
        return output(examples);
    }

    /** Returns the per-layer configuration backing this network. */
    public MultiLayerConfiguration getLayerWiseConfigurations() {
        return layerWiseConfigurations;
    }

    /** Replaces the per-layer configuration; existing layers are not rebuilt. */
    public void setLayerWiseConfigurations(MultiLayerConfiguration configurations) {
        layerWiseConfigurations = configurations;
    }

    /** Creates the exception thrown by the single-layer {@code Model} operations this class does not implement. */
    private static UnsupportedOperationException unsupportedModelOperation() {
        return new UnsupportedOperationException();
    }

    /**
     * Stores {@code iNDArray} as the network input and, on first call, builds the layers via {@link #init()}.
     *
     * @param iNDArray the network input; must not be {@code null}
     * @throws IllegalArgumentException if the input is {@code null}, or if it is a matrix (rank 2)
     *                                  and any configured hidden layer size is below 1
     */
    public void initializeLayers(INDArray iNDArray) {
        if (iNDArray == null) {
            throw new IllegalArgumentException("Unable to initialize neuralNets with empty input");
        }
        int[] sizes = getLayerWiseConfigurations().getHiddenLayerSizes();
        boolean matrixInput = iNDArray.shape().length == 2;
        if (matrixInput) {
            for (int k = 0; k < sizes.length; k++) {
                if (sizes[k] < 1) {
                    throw new IllegalArgumentException("All hidden layer sizes must be >= 1");
                }
            }
        }
        this.input = iNDArray;
        if (!this.initCalled) {
            init();
        }
    }

    /**
     * Builds the layer instances from the layer-wise configuration and wires each layer's nIn/nOut.
     * <p>
     * Layers are only built when {@code layers} is null or layer 0 has not been created yet.
     * If no input was set, {@code this.input} is replaced by a ones vector of length
     * {@code getConf(0).getnIn()}. That dummy input is propagated forward so intermediate
     * FEED_FORWARD layers take their nIn from the actual activation width. Intermediate layers are
     * created with the network's listeners; the first and last layers are created without them.
     * On success sets {@code initCalled} and calls {@code initMask()}.
     *
     * @throws IllegalStateException if the configuration describes fewer than one layer
     */
    public void init() {
        if (this.layerWiseConfigurations == null || this.layers == null) {
            intializeConfigurations();
        }
        INDArray input = input();
        // i holds the nIn for layer 0; later assignments (FEED_FORWARD branch below) are never read.
        int i = 0;
        if (getnLayers() < 1) {
            throw new IllegalStateException("Unable to createComplex network neuralNets; number specified is less than 1");
        }
        int[] hiddenLayerSizes = this.layerWiseConfigurations.getHiddenLayerSizes();
        // 1-based cursor advanced once per intermediate FEED_FORWARD layer.
        int i2 = 1;
        if (this.layers == null || this.layers[0] == null) {
            if (this.layers == null) {
                this.layers = new Layer[getnLayers()];
            }
            // The last layer is skipped here (i3 == length - 1 matches no branch) and built after the loop.
            for (int i3 = 0; i3 < getnLayers(); i3++) {
                if (i3 == 0) {
                    i = this.layerWiseConfigurations.getConf(0).getnIn();
                    if (this.input == null) {
                        // No real input yet: use a dummy ones vector so activations can size later layers.
                        this.input = Nd4j.ones(i);
                        input = this.input;
                    }
                } else if (LayerFactories.typeForFactory(this.layerWiseConfigurations.getConf(i3)) == Layer.Type.FEED_FORWARD) {
                    i = hiddenLayerSizes[i2 - 1];
                }
                if (i3 == 0) {
                    // NOTE(review): this FEED_FORWARD branch is redundant; the two statements after it
                    // set the same nIn/nOut unconditionally.
                    if (LayerFactories.typeForFactory(this.layerWiseConfigurations.getConf(i3)) == Layer.Type.FEED_FORWARD) {
                        this.layerWiseConfigurations.getConf(i3).setnIn(i);
                        this.layerWiseConfigurations.getConf(i3).setnOut(hiddenLayerSizes[i3]);
                    }
                    this.layerWiseConfigurations.getConf(i3).setnIn(i);
                    this.layerWiseConfigurations.getConf(i3).setnOut(hiddenLayerSizes[i3]);
                    this.layers[i3] = LayerFactories.getFactory(this.layerWiseConfigurations.getConf(i3)).create(this.layerWiseConfigurations.getConf(i3));
                } else if (i3 < getLayers().length - 1) {
                    // Propagate through the previous (already built) layer to learn this layer's input width.
                    if (this.input != null) {
                        input = activationFromPrevLayer(i3 - 1, input);
                    }
                    if (LayerFactories.typeForFactory(this.layerWiseConfigurations.getConf(i3)) == Layer.Type.FEED_FORWARD) {
                        i2++;
                        this.layerWiseConfigurations.getConf(i3).setnIn(input.columns());
                        this.layerWiseConfigurations.getConf(i3).setnOut(hiddenLayerSizes[i3]);
                    }
                    this.layers[i3] = LayerFactories.getFactory(this.layerWiseConfigurations.getConf(i3)).create(this.layerWiseConfigurations.getConf(i3), this.listeners);
                }
            }
            // Output (last) layer: a FEED_FORWARD output takes its nIn from the last hidden layer size.
            NeuralNetConfiguration conf = this.layerWiseConfigurations.getConf(this.layerWiseConfigurations.getConfs().size() - 1);
            if (LayerFactories.typeForFactory(this.layerWiseConfigurations.getConf(this.layerWiseConfigurations.getConfs().size() - 1)) == Layer.Type.FEED_FORWARD) {
                conf.setnIn(hiddenLayerSizes[hiddenLayerSizes.length - 1]);
            }
            this.layers[this.layers.length - 1] = LayerFactories.getFactory(conf).create(conf);
            this.initCalled = true;
            initMask();
        }
    }

    /**
     * Returns the activation of the last layer, delegating to its no-argument {@code activate()}.
     */
    public INDArray activate() {
        Layer[] all = getLayers();
        Layer last = all[all.length - 1];
        return last.activate();
    }

    /**
     * Returns the activation of layer {@code i}, delegating to its no-argument {@code activate()}.
     *
     * @param i zero-based layer index
     */
    public INDArray activate(int i) {
        Layer target = getLayers()[i];
        return target.activate();
    }

    /**
     * Activates layer {@code i} on the given input (no input pre-processors are applied).
     *
     * @param i        zero-based layer index
     * @param iNDArray the input for that layer
     */
    public INDArray activate(int i, INDArray iNDArray) {
        Layer target = getLayers()[i];
        return target.activate(iNDArray);
    }

    /**
     * Sets input and labels from a data set and runs one forward pass.
     * If the final layer is an {@link OutputLayer}, the labels are also handed to it.
     *
     * @param dataSet supplies the feature matrix and labels
     */
    public void initialize(DataSet dataSet) {
        INDArray features = dataSet.getFeatureMatrix();
        setInput(features);
        feedForward(getInput());
        this.labels = dataSet.getLabels();
        if (!(getOutputLayer() instanceof OutputLayer)) {
            return;
        }
        ((OutputLayer) getOutputLayer()).setLabels(this.labels);
    }

    /**
     * Runs layer {@code i} on {@code iNDArray}, applying the configured pre-processors.
     * <p>
     * The input pre-processor for {@code i} (if any) runs before activation. The output
     * pre-processor for {@code i} (if any, and if a processor map exists) runs on the result.
     *
     * @param i        zero-based layer index
     * @param iNDArray the input to layer {@code i}
     * @return the (possibly post-processed) activation of layer {@code i}
     */
    public INDArray activationFromPrevLayer(int i, INDArray iNDArray) {
        MultiLayerConfiguration confs = getLayerWiseConfigurations();
        INDArray layerInput = iNDArray;
        if (confs.getInputPreProcess(i) != null) {
            layerInput = confs.getInputPreProcess(i).preProcess(layerInput);
        }
        INDArray activation = this.layers[i].activate(layerInput);
        boolean hasOutputProcessor = confs.getProcessors() != null && confs.getPreProcessor(i) != null;
        if (!hasOutputProcessor) {
            return activation;
        }
        return confs.getPreProcessor(i).preProcess(activation);
    }

    /**
     * Propagates the current input through every layer.
     * <p>
     * Each activation passes through {@link #applyDropConnectIfNecessary(INDArray)}, which may
     * modify it in place before it is stored and fed to the next layer.
     *
     * @return a list of {@code layers.length + 1} arrays: the input followed by each layer's activation
     */
    public List<INDArray> feedForward() {
        INDArray current = this.input;
        List<INDArray> activations = new ArrayList<>(this.layers.length + 1);
        activations.add(current);
        for (int i = 0; i < this.layers.length; i++) {
            current = activationFromPrevLayer(i, current);
            applyDropConnectIfNecessary(current);
            activations.add(current);
        }
        return activations;
    }

    /**
     * Sets the network input (after the layer-0 input pre-processor, if configured) and runs
     * {@link #feedForward()}.
     *
     * @param iNDArray the raw network input
     * @return the input followed by each layer's activation
     * @throws IllegalStateException if {@code iNDArray} is {@code null}
     */
    public List<INDArray> feedForward(INDArray iNDArray) {
        if (iNDArray == null) {
            throw new IllegalStateException("Unable to perform feed forward; no input found");
        }
        this.input = getLayerWiseConfigurations().getInputPreProcess(0) == null
                ? iNDArray
                : getLayerWiseConfigurations().getInputPreProcess(0).preProcess(iNDArray);
        return feedForward();
    }

    /**
     * Collects each layer's gradient, keyed by the layer index as a string ("0", "1", ...).
     * <p>
     * Previously the loop advanced by two, which silently omitted every odd-indexed layer.
     *
     * @return a gradient holding one entry per layer
     */
    @Override // org.deeplearning4j.nn.api.Model
    public Gradient gradient() {
        DefaultGradient defaultGradient = new DefaultGradient();
        for (int i = 0; i < this.layers.length; i++) {
            defaultGradient.gradientForVariable().put(String.valueOf(i), this.layers[i].gradient().gradient());
        }
        return defaultGradient;
    }

    /**
     * Returns {@link #gradient()} paired with the output layer's current score.
     * The gradient is computed before the score is read.
     */
    @Override // org.deeplearning4j.nn.api.Model
    public Pair<Gradient, Double> gradientAndScore() {
        Gradient networkGradient = gradient();
        double outputScore = getOutputLayer().score();
        return new Pair<>(networkGradient, outputScore);
    }

    /**
     * When drop-connect is enabled, multiplies {@code iNDArray} in place by a binomial(1, 0.5)
     * sample of the same shape. It then also multiplies in place by the default configuration's
     * L2 value when that value is positive. Does nothing otherwise.
     *
     * @param iNDArray array modified in place
     */
    protected void applyDropConnectIfNecessary(INDArray iNDArray) {
        if (!this.layerWiseConfigurations.isUseDropConnect()) {
            return;
        }
        INDArray probabilities = Nd4j.valueArrayOf(iNDArray.slices(), iNDArray.columns(), 0.5d);
        INDArray dropMask = Nd4j.getDistributions().createBinomial(1, probabilities).sample(probabilities.shape());
        iNDArray.muli(dropMask);
        double l2 = this.defaultConfiguration.getL2();
        if (l2 > 0.0d) {
            iNDArray.muli(Double.valueOf(l2));
        }
    }

    /**
     * Computes per-layer weight deltas for the R-forward pass in direction {@code iNDArray}.
     * <p>
     * Runs {@link #feedForward()} and {@code feedForwardR} (defined elsewhere). It divides the last
     * R activation in place by the number of input slices, then back-propagates from the last layer
     * to the first. If the default configuration constrains gradients to unit norm, each delta
     * whose element sum is positive is divided by its L2 norm.
     *
     * @param iNDArray the direction vector passed to {@code feedForwardR}; presumably a flattened
     *                 parameter-sized vector (TODO confirm)
     * @return one delta per layer ({@code getnLayers()} entries)
     */
    protected List<INDArray> computeDeltasR(INDArray iNDArray) {
        ArrayList arrayList = new ArrayList();
        // Sized getnLayers() + 1, but only the first getnLayers() slots are filled and returned.
        INDArray[] iNDArrayArr = new INDArray[getnLayers() + 1];
        List<INDArray> feedForward = feedForward();
        List<INDArray> feedForwardR = feedForwardR(feedForward, iNDArray);
        // Per-layer weights, biases and activation-function names.
        ArrayList arrayList2 = new ArrayList();
        ArrayList arrayList3 = new ArrayList();
        ArrayList arrayList4 = new ArrayList();
        for (int i = 0; i < getLayers().length; i++) {
            arrayList2.add(getLayers()[i].getParam(DefaultParamInitializer.WEIGHT_KEY));
            arrayList3.add(getLayers()[i].getParam("b"));
            arrayList4.add(getLayers()[i].conf().getActivationFunction());
        }
        // divi() is in place, so the last element of feedForwardR is itself scaled.
        INDArray divi = feedForwardR.get(feedForwardR.size() - 1).divi(Double.valueOf(this.input.slices()));
        LinAlgExceptions.assertValidNum(divi);
        for (int i2 = getnLayers() - 1; i2 >= 0; i2--) {
            iNDArrayArr[i2] = feedForward.get(i2).transpose().mmul(divi);
            applyDropConnectIfNecessary(iNDArrayArr[i2]);
            if (i2 > 0) {
                // Propagate through (W with bias row added)^T, times the derivative of layer i2-1's
                // activation function evaluated at feedForward.get(i2).
                divi = divi.mmul(((INDArray) arrayList2.get(i2)).addRowVector((INDArray) arrayList3.get(i2)).transpose()).muli(Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform((String) arrayList4.get(i2 - 1), feedForward.get(i2)).derivative()));
            }
        }
        for (int i3 = 0; i3 < iNDArrayArr.length - 1; i3++) {
            if (!this.defaultConfiguration.isConstrainGradientToUnitNorm()) {
                arrayList.add(iNDArrayArr[i3]);
            } else if (iNDArrayArr[i3].sum(Integer.MAX_VALUE).getDouble(0) > 0.0d) {
                // NOTE(review): guards on the element sum rather than the norm; a non-zero delta with a
                // non-positive sum is left unnormalized.
                arrayList.add(iNDArrayArr[i3].div(iNDArrayArr[i3].norm2(Integer.MAX_VALUE)));
            } else {
                arrayList.add(iNDArrayArr[i3]);
            }
            LinAlgExceptions.assertValidNum((INDArray) arrayList.get(i3));
        }
        return arrayList;
    }

    /**
     * Rescales the configuration's damping factor based on a reduction ratio.
     * <p>
     * A ratio below 0.25 (or NaN) multiplies the factor by {@code lowRatioMultiplier}. A ratio
     * above 0.75 multiplies it by {@code highRatioMultiplier}. Anything in between leaves it unchanged.
     *
     * @param ratio               the reduction ratio, e.g. from {@link #reductionRatio}
     * @param lowRatioMultiplier  factor applied when the ratio is low or NaN
     * @param highRatioMultiplier factor applied when the ratio is high
     */
    public void dampingUpdate(double ratio, double lowRatioMultiplier, double highRatioMultiplier) {
        final double multiplier;
        if (Double.isNaN(ratio) || ratio < 0.25d) {
            multiplier = lowRatioMultiplier;
        } else if (ratio > 0.75d) {
            multiplier = highRatioMultiplier;
        } else {
            return;
        }
        this.layerWiseConfigurations.setDampingFactor(getLayerWiseConfigurations().getDampingFactor() * multiplier);
    }

    /**
     * Computes {@code (d - d2)} divided by a quadratic-model term built from the R-gradient.
     * <p>
     * The damping factor is temporarily set to 0 while {@code getBackPropRGradient(iNDArray)} is
     * computed, then restored. Returns {@link Double#NEGATIVE_INFINITY} when {@code d2 > d}.
     *
     * @param iNDArray  the direction passed to {@link #getBackPropRGradient}
     * @param d         presumably the objective before the step (TODO confirm)
     * @param d2        presumably the objective after the step (TODO confirm)
     * @param iNDArray2 multiplied element-wise with {@code iNDArray} and subtracted from the model term
     */
    public double reductionRatio(INDArray iNDArray, double d, double d2, INDArray iNDArray2) {
        double dampingFactor = this.layerWiseConfigurations.getDampingFactor();
        this.layerWiseConfigurations.setDampingFactor(0.0d);
        INDArray backPropRGradient = getBackPropRGradient(iNDArray);
        // NOTE(review): both muli() calls modify backPropRGradient in place, but the result of the
        // trailing sum(0) is discarded — confirm whether it was meant to be assigned.
        backPropRGradient.muli(Double.valueOf(0.5d)).muli(iNDArray.mul(backPropRGradient)).sum(0);
        backPropRGradient.subi(iNDArray2.mul(iNDArray).sum(0));
        double doubleValue = (d - d2) / ((Double) backPropRGradient.getScalar(0).element()).doubleValue();
        this.layerWiseConfigurations.setDampingFactor(dampingFactor);
        if (d2 - d > 0.0d) {
            return Double.NEGATIVE_INFINITY;
        }
        return doubleValue;
    }

    /**
     * Computes per-layer weight deltas and a squared-term companion by back-propagating the output error.
     * <p>
     * The output error is {@code (output - labels) / labels.slices()}. For layer k, the first element
     * of each pair is {@code a_k^T · delta} and the second is
     * {@code (a_k^T)^2 · delta^2 * labels.slices()} (element-wise powers), where {@code a_k} is the
     * k-th feed-forward array. If the default configuration constrains gradients to unit norm, the
     * first element is divided in place by its L2 norm.
     *
     * @return one (delta, squared delta) pair per layer
     */
    protected List<Pair<INDArray, INDArray>> computeDeltas2() {
        ArrayList arrayList = new ArrayList();
        List<INDArray> feedForward = feedForward();
        // feedForward() returns layers.length + 1 arrays, so these hold one slot per layer.
        INDArray[] iNDArrayArr = new INDArray[feedForward.size() - 1];
        INDArray[] iNDArrayArr2 = new INDArray[feedForward.size() - 1];
        INDArray div = feedForward.get(feedForward.size() - 1).sub(this.labels).div(Integer.valueOf(this.labels.slices()));
        // Per-layer weights, biases (collected but not used below) and activation-function names.
        ArrayList arrayList2 = new ArrayList();
        ArrayList arrayList3 = new ArrayList();
        ArrayList arrayList4 = new ArrayList();
        for (int i = 0; i < getLayers().length; i++) {
            arrayList2.add(getLayers()[i].getParam(DefaultParamInitializer.WEIGHT_KEY));
            arrayList3.add(getLayers()[i].getParam("b"));
            arrayList4.add(getLayers()[i].conf().getActivationFunction());
        }
        for (int size = arrayList2.size() - 1; size >= 0; size--) {
            iNDArrayArr[size] = feedForward.get(size).transpose().mmul(div);
            iNDArrayArr2[size] = Transforms.pow(feedForward.get(size).transpose(), 2).mmul(Transforms.pow(div, 2)).muli(Integer.valueOf(this.labels.slices()));
            applyDropConnectIfNecessary(iNDArrayArr[size]);
            if (size > 0) {
                // Propagate through W^T times the derivative of layer size-1's activation function
                // evaluated at feedForward.get(size).
                div = div.mmul(((INDArray) arrayList2.get(size)).transpose()).muli(Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform((String) arrayList4.get(size - 1), feedForward.get(size)).derivative()));
            }
        }
        for (int i2 = 0; i2 < iNDArrayArr.length; i2++) {
            if (this.defaultConfiguration.isConstrainGradientToUnitNorm()) {
                // NOTE(review): no zero-norm guard here (computeDeltasR has one), so an all-zero delta yields NaN.
                arrayList.add(new Pair(iNDArrayArr[i2].divi(iNDArrayArr[i2].norm2(Integer.MAX_VALUE)), iNDArrayArr2[i2]));
            } else {
                arrayList.add(new Pair(iNDArrayArr[i2], iNDArrayArr2[i2]));
            }
        }
        return arrayList;
    }

    /**
     * Flattens the R-operator back-prop gradient for direction {@code iNDArray} into one vector.
     *
     * @param iNDArray direction passed to {@code backPropGradientR}
     * @return the packed gradient
     */
    public INDArray getBackPropRGradient(INDArray iNDArray) {
        return pack(backPropGradientR(iNDArray));
    }

    /**
     * Packs the results of {@link #backPropGradient2()} into two flat vectors.
     *
     * @return a pair of (packed gradients, packed squared-gradient terms)
     */
    public Pair<INDArray, INDArray> getBackPropGradient2() {
        List<Pair<Pair<INDArray, INDArray>, Pair<INDArray, INDArray>>> perLayer = backPropGradient2();
        List<Pair<INDArray, INDArray>> gradients = new ArrayList<>(perLayer.size());
        List<Pair<INDArray, INDArray>> squaredTerms = new ArrayList<>(perLayer.size());
        for (Pair<Pair<INDArray, INDArray>, Pair<INDArray, INDArray>> entry : perLayer) {
            gradients.add(entry.getFirst());
            squaredTerms.add(entry.getSecond());
        }
        return new Pair<>(pack(gradients), pack(squaredTerms));
    }

    /**
     * Creates a copy of this network (decompiled name of {@code clone()}).
     * <p>
     * The copy is built through the runtime class's {@code (MultiLayerConfiguration)} constructor,
     * sharing this network's configuration object. It then receives this network's default
     * configuration, input, labels and a shallow copy of the layer array via {@link #update}.
     * The layer objects themselves are shared.
     *
     * @return the new network
     * @throws IllegalStateException if the copy cannot be constructed; the original failure is kept as the cause
     */
    public MultiLayerNetwork m31clone() {
        try {
            MultiLayerNetwork copy = getClass()
                    .getDeclaredConstructor(MultiLayerConfiguration.class)
                    .newInstance(getLayerWiseConfigurations());
            copy.update(this);
            return copy;
        } catch (Exception e) {
            throw new IllegalStateException("Unable to clone network", e);
        }
    }

    /**
     * Returns all layer parameters flattened into a single vector, in layer order.
     */
    @Override // org.deeplearning4j.nn.api.Model
    public INDArray params() {
        List<INDArray> layerParams = new ArrayList<>();
        for (int i = 0; i < getnLayers(); i++) {
            layerParams.add(this.layers[i].params());
        }
        return Nd4j.toFlattened(layerParams);
    }

    /** Loads a flattened parameter vector; delegates to {@code setParameters}. */
    @Override // org.deeplearning4j.nn.api.Model
    public void setParams(INDArray iNDArray) {
        setParameters(iNDArray);
    }

    /** Returns the total parameter count summed over all layers. */
    @Override // org.deeplearning4j.nn.api.Model
    public int numParams() {
        int total = 0;
        for (Layer layer : this.layers) {
            total += layer.numParams();
        }
        return total;
    }

    /** Same as {@link #params()}. */
    public INDArray pack() {
        return params();
    }

    /**
     * Flattens (first, second) pairs into one vector, interleaved as first0, second0, first1, ...
     *
     * @param list the pairs to flatten, typically (weights, bias) per layer
     */
    public INDArray pack(List<Pair<INDArray, INDArray>> list) {
        List<INDArray> parts = new ArrayList<>(list.size() * 2);
        for (Pair<INDArray, INDArray> pair : list) {
            parts.add(pair.getFirst());
            parts.add(pair.getSecond());
        }
        return Nd4j.toFlattened(parts);
    }

    /** Scores a data set; delegates to {@link #score(INDArray, INDArray)}. */
    @Override // org.deeplearning4j.nn.api.Classifier
    public double score(org.nd4j.linalg.dataset.api.DataSet dataSet) {
        return score(dataSet.getFeatureMatrix(), dataSet.getLabels());
    }

    /**
     * Splits a flattened parameter vector into per-layer (weights, bias) pairs.
     * <p>
     * The vector is reshaped to a single row if it has more than one slice. It is consumed in
     * layer order, each layer taking {@code W.length() + b.length()} entries. Each piece is reshaped
     * to the shape of the layer's current {@code W} and {@code b} parameters.
     *
     * @param iNDArray parameters laid out as produced by {@link #pack(List)}
     * @return one (weights, bias) pair per layer
     * @throws IllegalStateException if a layer's slice lengths do not match its parameter lengths
     */
    public List<Pair<INDArray, INDArray>> unPack(INDArray iNDArray) {
        if (iNDArray.slices() != 1) {
            iNDArray = iNDArray.reshape(1, iNDArray.length());
        }
        ArrayList arrayList = new ArrayList();
        // Running offset into the flattened vector.
        int i = 0;
        for (int i2 = 0; i2 < this.layers.length; i2++) {
            int length = this.layers[i2].getParam(DefaultParamInitializer.WEIGHT_KEY).length() + this.layers[i2].getParam("b").length();
            INDArray iNDArray2 = iNDArray.get(new NDArrayIndex[]{NDArrayIndex.interval(i, i + length)});
            INDArray iNDArray3 = iNDArray2.get(new NDArrayIndex[]{NDArrayIndex.interval(0, this.layers[i2].getParam(DefaultParamInitializer.WEIGHT_KEY).length())});
            INDArray iNDArray4 = iNDArray2.get(new NDArrayIndex[]{NDArrayIndex.interval(this.layers[i2].getParam(DefaultParamInitializer.WEIGHT_KEY).length(), iNDArray2.length())});
            // The individual checks only run when the combined length is already wrong.
            if (iNDArray3.length() + iNDArray4.length() != length) {
                if (iNDArray4.length() != this.layers[i2].getParam("b").length()) {
                    throw new IllegalStateException("Hidden bias on layer " + i2 + " was off");
                }
                if (iNDArray3.length() != this.layers[i2].getParam(DefaultParamInitializer.WEIGHT_KEY).length()) {
                    throw new IllegalStateException("Weight portion on layer " + i2 + " was off");
                }
            }
            arrayList.add(new Pair(iNDArray3.reshape(this.layers[i2].getParam(DefaultParamInitializer.WEIGHT_KEY).slices(), this.layers[i2].getParam(DefaultParamInitializer.WEIGHT_KEY).columns()), iNDArray4.reshape(this.layers[i2].getParam("b").slices(), this.layers[i2].getParam("b").columns())));
            i += length;
        }
        return arrayList;
    }

    /**
     * Builds per-layer gradients and squared-gradient terms with L2 and damping contributions.
     * <p>
     * For each pair from {@link #computeDeltas2()}, the weight delta is paired with its mean over
     * dimension 0 as the bias term, and likewise for the squared delta. Both lists are packed.
     * {@code L2 * params * mask} is added to the gradient, and
     * {@code (mask * L2 + dampingFactor)^0.75} to the squared terms. Both are then unpacked
     * per layer and zipped.
     *
     * @return per layer: ((W grad, b grad), (W squared term, b squared term))
     * @throws IllegalStateException if a delta's length disagrees with the layer's W or b length
     */
    protected List<Pair<Pair<INDArray, INDArray>, Pair<INDArray, INDArray>>> backPropGradient2() {
        List<Pair<INDArray, INDArray>> computeDeltas2 = computeDeltas2();
        ArrayList arrayList = new ArrayList();
        ArrayList arrayList2 = new ArrayList();
        ArrayList arrayList3 = new ArrayList();
        for (int i = 0; i < computeDeltas2.size(); i++) {
            INDArray first = computeDeltas2.get(i).getFirst();
            INDArray second = computeDeltas2.get(i).getSecond();
            if (i < this.layers.length && first.length() != this.layers[i].getParam(DefaultParamInitializer.WEIGHT_KEY).length()) {
                throw new IllegalStateException("Gradient change not equal to weight change");
            }
            INDArray mean = computeDeltas2.get(i).getFirst().mean(0);
            INDArray mean2 = computeDeltas2.get(i).getSecond().mean(0);
            arrayList2.add(new Pair<>(first, mean));
            arrayList3.add(new Pair<>(second, mean2));
            if (i < this.layers.length && mean.length() != this.layers[i].getParam("b").length()) {
                throw new IllegalStateException("Bias change not equal to weight change");
            }
            // NOTE(review): computeDeltas2() returns one entry per layer, so i never reaches
            // getLayers().length and this check appears unreachable.
            if (i == getLayers().length && mean.length() != getOutputLayer().getParam("b").length()) {
                throw new IllegalStateException("Bias change not equal to weight change");
            }
        }
        INDArray pack = pack(arrayList2);
        INDArray pack2 = pack(arrayList3);
        INDArray params = params();
        if (this.mask == null) {
            initMask();
        }
        // L2 penalty gradient, restricted by the parameter mask.
        pack.addi(params.mul(Double.valueOf(this.defaultConfiguration.getL2())).muli(this.mask));
        pack2.addi(Transforms.pow(this.mask.mul(Double.valueOf(this.defaultConfiguration.getL2())).add(Nd4j.valueArrayOf(pack.slices(), pack.columns(), this.layerWiseConfigurations.getDampingFactor())), Double.valueOf(0.75d)));
        List<Pair<INDArray, INDArray>> unPack = unPack(pack);
        List<Pair<INDArray, INDArray>> unPack2 = unPack(pack2);
        for (int i2 = 0; i2 < unPack.size(); i2++) {
            arrayList.add(new Pair(unPack.get(i2), unPack2.get(i2)));
        }
        return arrayList;
    }

    /**
     * Trains on every batch of the iterator.
     * <p>
     * When pretraining is enabled, runs layer-wise pretraining followed by output-layer
     * fine-tuning. When backward training is enabled, then runs back-propagation on each batch.
     * The iterator is reset between phases.
     *
     * @param dataSetIterator source of training batches; must support {@code reset()}
     */
    @Override // org.deeplearning4j.nn.api.Classifier
    public void fit(DataSetIterator dataSetIterator) {
        if (this.layerWiseConfigurations.isPretrain()) {
            pretrain(dataSetIterator);
            dataSetIterator.reset();
            finetune(dataSetIterator);
        }
        if (!this.layerWiseConfigurations.isBackward()) {
            return;
        }
        dataSetIterator.reset();
        while (dataSetIterator.hasNext()) {
            DataSet batch = (DataSet) dataSetIterator.next();
            INDArray features = batch.getFeatureMatrix();
            INDArray batchLabels = batch.getLabels();
            doBackWard(features, batchLabels);
        }
    }

    /**
     * Runs back-propagation on one batch for {@code getConf(0).getNumIterations()} iterations.
     * <p>
     * Each iteration feeds forward, seeds the output error, and walks {@code backwardGradient}
     * from the last layer to the first. Each layer's gradient is then computed, adjusted by
     * {@link GradientAdjustment#updateGradientAccordingToParams}, and applied. Listeners are
     * notified after every iteration. Returns early (with a warning) if the final layer is not an
     * {@link OutputLayer}.
     *
     * @param iNDArray  the batch features
     * @param iNDArray2 the batch labels
     * @throws IllegalStateException if the labels are {@code null}
     */
    protected void doBackWard(INDArray iNDArray, INDArray iNDArray2) {
        setInput(iNDArray);
        this.labels = iNDArray2;
        feedForward();
        if (!(getOutputLayer() instanceof OutputLayer)) {
            log.warn("Warning: final layer isn't output layer. You can ignore this message if you just intend on using a a deep neural network with no output layer.");
            return;
        }
        OutputLayer outputLayer = (OutputLayer) getOutputLayer();
        if (iNDArray2 == null) {
            throw new IllegalStateException("No labels found");
        }
        outputLayer.setLabels(iNDArray2);
        Gradient[] gradientArr = new Gradient[getnLayers()];
        for (int i = 0; i < getLayerWiseConfigurations().getConf(0).getNumIterations(); i++) {
            List<INDArray> feedForward = feedForward();
            INDArray iNDArray3 = feedForward.get(feedForward.size() - 1);
            // Output error: (labels - output) minus the output activation's derivative.
            // NOTE(review): the derivative is subtracted, not multiplied — confirm this is intended.
            INDArray subi = iNDArray2.sub(iNDArray3).subi(Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform(getOutputLayer().conf().getActivationFunction(), iNDArray3).derivative()));
            Gradient defaultGradient = new DefaultGradient();
            defaultGradient.gradientForVariable().put(DefaultParamInitializer.WEIGHT_KEY, subi);
            // Overwritten by the first pass of the loop below (i2 == getnLayers() - 1).
            gradientArr[gradientArr.length - 1] = defaultGradient;
            for (int i2 = getnLayers() - 1; i2 >= 0; i2--) {
                defaultGradient = getLayers()[i2].backwardGradient(feedForward.get(i2), defaultGradient);
                gradientArr[i2] = defaultGradient;
            }
            for (int i3 = 0; i3 < getnLayers(); i3++) {
                Gradient calcGradient = getLayers()[i3].calcGradient(gradientArr[i3], feedForward.get(i3));
                GradientAdjustment.updateGradientAccordingToParams(getLayers()[i3].conf(), i, calcGradient, iNDArray.slices(), getLayers()[i3].getOptimizer().adaGradForVariables(), getLayers()[i3]);
                getLayers()[i3].update(calcGradient);
            }
            Iterator<IterationListener> it = this.listeners.iterator();
            while (it.hasNext()) {
                it.next().iterationDone(getOutputLayer(), i);
            }
        }
    }

    /** Returns the live list of iteration listeners (not a copy). */
    public List<IterationListener> getListeners() {
        return listeners;
    }

    /**
     * Replaces the iteration listeners and pushes them to every layer, building the layers first
     * via {@link #init()} if they do not exist yet.
     *
     * @param list the listeners to install; stored by reference
     */
    public void setListeners(List<IterationListener> list) {
        listeners = list;
        if (layers == null) {
            init();
        }
        for (int idx = 0; idx < layers.length; idx++) {
            layers[idx].setIterationListeners(list);
        }
    }

    /**
     * Fine-tunes the output layer on every batch of the iterator (after resetting it).
     * <p>
     * Each batch becomes the network input and labels. After a forward pass, an
     * {@link OutputLayer} is fitted on its own input. Stops at the first batch with missing
     * features or labels.
     *
     * @param dataSetIterator source of labelled batches
     * @throws UnsupportedOperationException if the output layer is configured for Hessian-free optimization
     */
    public void finetune(DataSetIterator dataSetIterator) {
        log.info("Finetune phase ");
        dataSetIterator.reset();
        while (dataSetIterator.hasNext()) {
            DataSet batch = (DataSet) dataSetIterator.next();
            INDArray features = batch.getFeatureMatrix();
            INDArray batchLabels = batch.getLabels();
            if (features == null || batchLabels == null) {
                return;
            }
            setInput(features);
            setLabels(batchLabels);
            boolean hessianFree = getOutputLayer().conf().getOptimizationAlgo() == OptimizationAlgorithm.HESSIAN_FREE;
            if (hessianFree) {
                throw new UnsupportedOperationException();
            }
            feedForward();
            if (!(getOutputLayer() instanceof OutputLayer)) {
                continue;
            }
            OutputLayer out = (OutputLayer) getOutputLayer();
            out.fit(out.input(), getLabels());
        }
    }

    /**
     * Fine-tunes the output layer on the current input.
     * <p>
     * Runs a forward pass and fits the {@link OutputLayer} on the activation feeding it
     * (the second-to-last feed-forward array).
     *
     * @param iNDArray new labels, or {@code null} to keep the previously set labels. The stored
     *                 labels are what gets fitted, so a {@code null} argument no longer passes
     *                 {@code null} to the output layer.
     * @throws UnsupportedOperationException if the output layer is configured for Hessian-free optimization
     */
    public void finetune(INDArray iNDArray) {
        if (iNDArray != null) {
            this.labels = iNDArray;
        }
        if (!(getOutputLayer() instanceof OutputLayer)) {
            log.warn("Output layer not instance of output layer returning.");
            return;
        }
        log.info("Finetune phase");
        OutputLayer outputLayer = (OutputLayer) getOutputLayer();
        if (getOutputLayer().conf().getOptimizationAlgo() == OptimizationAlgorithm.HESSIAN_FREE) {
            throw new UnsupportedOperationException();
        }
        List<INDArray> feedForward = feedForward();
        outputLayer.fit(feedForward.get(feedForward.size() - 2), this.labels);
    }

    /**
     * Predicts a class index per example.
     * <p>
     * Runs {@link #output(INDArray)} and takes, per row, the index returned by BLAS {@code iamax}
     * (largest absolute value). A row-vector input yields a single prediction from the whole output.
     *
     * @param iNDArray examples, one per slice
     * @return one predicted index per input slice
     */
    @Override // org.deeplearning4j.nn.api.Classifier
    public int[] predict(INDArray iNDArray) {
        INDArray output = output(iNDArray);
        int[] predictions = new int[iNDArray.slices()];
        if (iNDArray.isRowVector()) {
            predictions[0] = Nd4j.getBlasWrapper().iamax(output);
            return predictions;
        }
        for (int row = 0; row < predictions.length; row++) {
            predictions[row] = Nd4j.getBlasWrapper().iamax(output.getRow(row));
        }
        return predictions;
    }

    /**
     * Runs a forward pass on {@code iNDArray} and returns the output layer's
     * {@code labelProbabilities} for the final feed-forward array.
     * <p>
     * NOTE(review): the final array is already the output layer's activation, while
     * {@code finetune(INDArray)} feeds the output layer the second-to-last array. Confirm whether
     * {@code size() - 2} was intended here. Throws ClassCastException if the final layer is not an
     * {@link OutputLayer}.
     *
     * @param iNDArray the examples to classify
     */
    @Override // org.deeplearning4j.nn.api.Classifier
    public INDArray labelProbabilities(INDArray iNDArray) {
        List<INDArray> feedForward = feedForward(iNDArray);
        return ((OutputLayer) getOutputLayer()).labelProbabilities(feedForward.get(feedForward.size() - 1));
    }

    @Override // org.deeplearning4j.nn.api.Classifier
    public void fit(INDArray iNDArray, INDArray iNDArray2) {
        setInput(iNDArray);
        if (this.layerWiseConfigurations.isPretrain()) {
            pretrain(getInput());
            finetune(iNDArray2);
        }
        if (this.layerWiseConfigurations.isBackward()) {
            doBackWard(getInput(), iNDArray2);
        }
    }

    /**
     * Unsupervised fit: delegates directly to {@code pretrain} on the given examples.
     *
     * @param iNDArray the training examples
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void fit(INDArray iNDArray) {
        this.pretrain(iNDArray);
    }

    /**
     * Runs one training pass on the given examples by delegating to {@code pretrain}.
     *
     * @param iNDArray the training examples
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void iterate(INDArray iNDArray) {
        this.pretrain(iNDArray);
    }

    /**
     * Trains on a data set by unpacking its feature matrix and labels and delegating to
     * {@code fit(INDArray, INDArray)}.
     *
     * @param dataSet the features and labels to train on
     */
    @Override // org.deeplearning4j.nn.api.Classifier
    public void fit(org.nd4j.linalg.dataset.api.DataSet dataSet) {
        INDArray features = dataSet.getFeatureMatrix();
        INDArray targets = dataSet.getLabels();
        fit(features, targets);
    }

    /**
     * Trains on integer class labels.
     *
     * <p>The labels are converted to an outcome matrix whose width is the output layer's
     * configured {@code nOut}.
     *
     * @param iNDArray the training examples
     * @param iArr one class index per example
     */
    @Override // org.deeplearning4j.nn.api.Classifier
    public void fit(INDArray iNDArray, int[] iArr) {
        int numClasses = getOutputLayer().conf().getnOut();
        INDArray outcomes = FeatureUtil.toOutcomeMatrix(iArr, numClasses);
        fit(iNDArray, outcomes);
    }

    /**
     * Feeds the examples forward and returns the final activation, i.e. the network output.
     *
     * @param iNDArray the examples
     * @return the last entry of the feed-forward activation list
     */
    public INDArray output(INDArray iNDArray) {
        List<INDArray> activations = feedForward(iNDArray);
        int lastIndex = activations.size() - 1;
        return activations.get(lastIndex);
    }

    /**
     * Feeds the examples forward and returns the activation at position {@code i - 1} of
     * the feed-forward activation list.
     *
     * @param iNDArray the examples
     * @param i one-based position in the activation list
     * @return the selected activation
     * @throws IndexOutOfBoundsException if {@code i - 1} is not a valid list index
     */
    public INDArray reconstruct(INDArray iNDArray, int i) {
        int index = i - 1;
        List<INDArray> activations = feedForward(iNDArray);
        return activations.get(index);
    }

    /**
     * Logs every per-layer configuration on a single INFO line, each entry prefixed with its
     * zero-based layer index.
     */
    public void printConfiguration() {
        StringBuilder sb = new StringBuilder();
        int layerIndex = 0;
        for (NeuralNetConfiguration conf : getLayerWiseConfigurations().getConfs()) {
            sb.append(" Layer ").append(layerIndex).append(" conf ").append(conf);
            layerIndex++;
        }
        log.info(sb.toString());
    }

    /**
     * Copies the default configuration, input, labels and layer array from another network.
     *
     * <p>The layer array is shallow-copied with {@code ArrayUtils.clone}, so the {@code Layer}
     * objects themselves are shared with {@code multiLayerNetwork}.
     *
     * @param multiLayerNetwork the network to copy state from
     */
    public void update(MultiLayerNetwork multiLayerNetwork) {
        MultiLayerNetwork source = multiLayerNetwork;
        this.defaultConfiguration = source.defaultConfiguration;
        this.input = source.input;
        this.labels = source.labels;
        this.layers = ArrayUtils.clone(source.layers);
    }

    /**
     * Scores the network on labeled examples using the F1 measure.
     *
     * <p>Feeds the examples forward, stores the labels, and evaluates the labels against
     * {@code labelProbabilities(iNDArray)}.
     *
     * @param iNDArray the examples
     * @param iNDArray2 the true labels
     * @return the F1 score reported by {@link Evaluation#f1()}
     */
    @Override // org.deeplearning4j.nn.api.Classifier
    public double score(INDArray iNDArray, INDArray iNDArray2) {
        feedForward(iNDArray);
        setLabels(iNDArray2);
        Evaluation eval = new Evaluation();
        INDArray probabilities = labelProbabilities(iNDArray);
        eval.eval(iNDArray2, probabilities);
        return eval.f1();
    }

    /**
     * Returns the number of label columns.
     *
     * <p>Uses the stored labels when present. Before any labels have been set, it falls back
     * to the output layer's configured {@code nOut}, the same width {@code fit(INDArray, int[])}
     * uses for its outcome matrix, instead of throwing {@link NullPointerException}.
     *
     * @return the number of labels
     */
    @Override // org.deeplearning4j.nn.api.Classifier
    public int numLabels() {
        if (this.labels == null) {
            return getOutputLayer().conf().getnOut();
        }
        return this.labels.columns();
    }

    /**
     * Feeds the data set's features forward, stores its labels, and returns {@code score()}.
     *
     * @param dataSet the features and labels to score
     * @return the output layer's score after the forward pass
     */
    public double score(DataSet dataSet) {
        INDArray features = dataSet.getFeatureMatrix();
        feedForward(features);
        setLabels(dataSet.getLabels());
        return score();
    }

    /**
     * Trains on the input and labels currently stored on this network.
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void fit() {
        INDArray features = this.input;
        INDArray targets = this.labels;
        fit(features, targets);
    }

    /**
     * No-op: this implementation ignores the supplied gradient and changes no parameters.
     *
     * <p>NOTE(review): the {@code Model} contract presumably expects parameters to be updated
     * here; confirm whether updates are intentionally handled elsewhere (e.g. per layer).
     *
     * @param gradient ignored
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void update(Gradient gradient) {
    }

    /**
     * Returns the output layer's score.
     *
     * <p>If the output layer has no input yet, a forward pass over the stored input is run
     * first.
     *
     * @return the output layer's score
     */
    @Override // org.deeplearning4j.nn.api.Model
    public double score() {
        boolean needsForwardPass = getOutputLayer().input() == null;
        if (needsForwardPass) {
            feedForward();
        }
        return getOutputLayer().score();
    }

    /**
     * No-op. The network score is read from the output layer in {@code score()}, so nothing is
     * stored here.
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void setScore() {
    }

    /**
     * No-op: the supplied value is ignored and no score is accumulated.
     *
     * @param d ignored
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void accumulateScore(double d) {
    }

    /**
     * Clears every layer and drops the stored input.
     *
     * <p>Safe to call before the layers have been initialized: {@code layers} is legitimately
     * {@code null} until {@code setInput} triggers initialization, and previously this threw
     * {@link NullPointerException}.
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void clear() {
        if (this.layers != null) {
            for (Layer layer : this.layers) {
                layer.clear();
            }
        }
        this.input = null;
    }

    /**
     * Scores a candidate parameter vector without permanently changing the network.
     *
     * <p>The candidate parameters are installed temporarily and {@code score()} is evaluated.
     * The L2 penalty {@code 0.5 * l2 * sum((mask * params)^2)} is then added. The previous
     * parameters are always restored, even if scoring throws.
     *
     * @param iNDArray the flattened candidate parameters
     * @return the score at {@code iNDArray} plus its L2 penalty
     */
    public double score(INDArray iNDArray) {
        // The mask is otherwise only initialized lazily in backPropGradientR; without this,
        // calling score(params) first would throw NullPointerException.
        if (this.mask == null) {
            initMask();
        }
        INDArray previous = params();
        setParameters(iNDArray);
        try {
            double score = score();
            // sum(Integer.MAX_VALUE) reduces over the whole array.
            double sumOfSquares = ((Double) Transforms.pow(this.mask.mul(iNDArray), 2).sum(Integer.MAX_VALUE).element()).doubleValue();
            double l2 = 0.5d * this.defaultConfiguration.getL2() * sumOfSquares;
            return score + l2;
        } finally {
            setParameters(previous);
        }
    }

    /**
     * Merges another network of identical depth into this one, layer by layer.
     *
     * <p>Each layer is merged exactly once. The output layer is the last element of
     * {@code layers}, so it is already covered by the loop. The previous extra
     * {@code getOutputLayer().merge(...)} call merged it a second time.
     *
     * @param multiLayerNetwork the network to merge in
     * @param i the batch size forwarded to each {@code Layer.merge}
     * @throws IllegalArgumentException if the two networks have different numbers of layers
     */
    public void merge(MultiLayerNetwork multiLayerNetwork, int i) {
        if (multiLayerNetwork.layers.length != this.layers.length) {
            throw new IllegalArgumentException("Unable to merge networks that are not of equal length");
        }
        for (int idx = 0; idx < this.layers.length; idx++) {
            this.layers[idx].merge(multiLayerNetwork.layers[idx], i);
        }
    }

    /**
     * Stores the network input, applying the layer-0 input pre-processor when one is
     * configured.
     *
     * <p>If the layers do not exist yet, they are initialized from the (possibly
     * pre-processed) input. Otherwise, if the stored input ended up {@code null}, the raw
     * argument is stored instead.
     *
     * @param iNDArray the raw input examples
     */
    public void setInput(INDArray iNDArray) {
        INDArray processed = iNDArray;
        if (getLayerWiseConfigurations().getInputPreProcess(0) != null) {
            processed = this.layerWiseConfigurations.getInputPreProcess(0).preProcess(iNDArray);
        }
        this.input = processed;
        if (this.layers == null) {
            initializeLayers(getInput());
            return;
        }
        if (this.input == null) {
            this.input = iNDArray;
        }
    }

    /**
     * Initializes the parameter mask to a single row of ones, one entry per packed
     * parameter.
     */
    private void initMask() {
        INDArray allOnes = Nd4j.ones(1, pack().length());
        setMask(allOnes);
    }

    /**
     * Returns the last layer of the network.
     *
     * @return the final element of {@code getLayers()}
     */
    public Layer getOutputLayer() {
        Layer[] all = getLayers();
        return all[all.length - 1];
    }

    /**
     * Distributes a flattened parameter vector across the layers, in layer order.
     *
     * <p>Each layer receives the next {@code layer.numParams()} consecutive entries.
     *
     * @param iNDArray the flattened parameters for all layers
     * @throws IllegalStateException if a layer's slice turns out empty
     */
    public void setParameters(INDArray iNDArray) {
        int offset = 0;
        for (Layer layer : getLayers()) {
            int numParams = layer.numParams();
            // NDArrayIndex.interval(begin, end) excludes end, so this selects numParams entries.
            INDArray layerParams = iNDArray.get(new NDArrayIndex[]{NDArrayIndex.interval(offset, offset + numParams)});
            if (layerParams.length() < 1) {
                throw new IllegalStateException("Unable to retrieve layer. No params found (length was 0");
            }
            layer.setParams(layerParams);
            // Advance past exactly this layer's slice. Advancing by numParams - 1 made
            // consecutive layers share one element and shifted every later layer.
            offset += numParams;
        }
    }

    /**
     * Forward pass of directional derivatives ("R" pass; the name suggests the R-operator used
     * by Hessian-free optimization — confirm).
     *
     * <p>Starts from a zero matrix shaped like the stored input. Then, for each layer i, it
     * appends {@code (R_i * W_i + a_i * (VW_i addiRowVector Vb_i)) * f'(a_{i+1})}. Here
     * {@code R_i} is the previous entry of the result, {@code W_i} is the layer's weight matrix,
     * {@code a_i} is {@code list.get(i)}, and {@code (VW_i, Vb_i)} is the i-th pair from
     * {@code unPack(iNDArray)}.
     *
     * @param list activations from a forward pass; must hold at least {@code layers.length + 1}
     *     entries because {@code list.get(i + 1)} is read
     * @param iNDArray the flattened direction vector, unpacked into per-layer (weight, bias)
     *     pairs
     * @return {@code layers.length + 1} arrays: the zero start followed by one entry per layer
     */
    public List<INDArray> feedForwardR(List<INDArray> list, INDArray iNDArray) {
        ArrayList arrayList = new ArrayList();
        // R-activation of the input is zero: the input does not depend on the parameters.
        arrayList.add(Nd4j.zeros(this.input.slices(), this.input.columns()));
        List<Pair<INDArray, INDArray>> unPack = unPack(iNDArray);
        List<INDArray> weightMatrices = MultiLayerUtil.weightMatrices(this);
        for (int i = 0; i < this.layers.length; i++) {
            // NOTE(review): addiRowVector mutates the unpacked direction weights in place. If
            // unPack returns views of iNDArray, the caller's vector is modified. Adding the bias
            // row into the weight direction before the mmul also looks unusual; confirm against
            // the reference derivation.
            // NOTE(review): the derivative op is created over list.get(i + 1); whether
            // execAndReturn overwrites that activation in place depends on nd4j — verify.
            arrayList.add(((INDArray) arrayList.get(i)).mmul(weightMatrices.get(i)).addi(list.get(i).mmul(unPack.get(i).getFirst().addiRowVector(unPack.get(i).getSecond()))).muli(Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform(getLayers()[i].conf().getActivationFunction(), list.get(i + 1)).derivative())));
        }
        return arrayList;
    }

    /**
     * Runs the R forward pass using activations from a fresh forward pass over the stored
     * input.
     *
     * @param iNDArray the flattened direction vector
     * @return the R-activations, one more entry than there are layers
     */
    public List<INDArray> feedForwardR(INDArray iNDArray) {
        List<INDArray> activations = feedForward();
        return feedForwardR(activations, iNDArray);
    }

    /**
     * Computes per-layer (weight, bias) R-gradients for a direction vector and adds the
     * regularization terms.
     *
     * <p>For each layer, the weight term is the R-delta from {@code computeDeltasR}, and the
     * bias term is that delta's mean along dimension 0. Both are checked against the layer's
     * parameter sizes. The packed result is increased by {@code mask * l2 * v} and by
     * {@code dampingFactor * v}, then unpacked again.
     *
     * @param iNDArray the flattened direction vector {@code v}
     * @return one (weight, bias) pair per layer
     * @throws IllegalStateException if a delta's size disagrees with the layer's weights or
     *     bias
     */
    protected List<Pair<INDArray, INDArray>> backPropGradientR(INDArray iNDArray) {
        if (this.mask == null) {
            initMask();
        }
        final List<INDArray> deltas = computeDeltasR(iNDArray);
        final List<Pair<INDArray, INDArray>> layerGradients = new ArrayList<>();
        for (int layerIdx = 0; layerIdx < getnLayers(); layerIdx++) {
            final INDArray weightGradient = deltas.get(layerIdx);
            if (weightGradient.length() != getLayers()[layerIdx].getParam(DefaultParamInitializer.WEIGHT_KEY).length()) {
                throw new IllegalStateException("Gradient change not equal to weight change");
            }
            final INDArray biasGradient = weightGradient.mean(0);
            if (biasGradient.length() != this.layers[layerIdx].getParam("b").length()) {
                throw new IllegalStateException("Bias change not equal to weight change");
            }
            layerGradients.add(new Pair<>(weightGradient, biasGradient));
        }
        final INDArray packed = pack(layerGradients);
        final INDArray l2Penalty = this.mask.mul(Double.valueOf(this.defaultConfiguration.getL2())).muli(iNDArray);
        final INDArray withPenalty = packed.addi(l2Penalty);
        final INDArray damping = iNDArray.mul(Double.valueOf(this.layerWiseConfigurations.getDampingFactor()));
        return unPack(withPenalty.addi(damping));
    }

    /**
     * @return the labels currently stored on this network (may be {@code null})
     */
    public INDArray getLabels() {
        return labels;
    }

    /**
     * @return the (possibly pre-processed) input currently stored on this network
     */
    public INDArray getInput() {
        return input;
    }

    /**
     * Replaces the stored labels.
     *
     * @param iNDArray the new labels
     */
    public void setLabels(INDArray iNDArray) {
        labels = iNDArray;
    }

    /**
     * @return the number of per-layer configurations in the layer-wise configuration
     */
    public int getnLayers() {
        return layerWiseConfigurations.getConfs().size();
    }

    /**
     * @return the backing layer array (not a copy; {@code null} before initialization)
     */
    public Layer[] getLayers() {
        return layers;
    }

    /**
     * Replaces the layer array. The array is stored as-is, not copied.
     *
     * @param layerArr the new layers
     */
    public void setLayers(Layer[] layerArr) {
        layers = layerArr;
    }

    /**
     * @return the parameter mask used for L2 regularization (may be {@code null} until
     *     initialized)
     */
    public INDArray getMask() {
        return mask;
    }

    /**
     * Replaces the parameter mask.
     *
     * @param iNDArray the new mask
     */
    public void setMask(INDArray iNDArray) {
        mask = iNDArray;
    }
}
