package org.deeplearning4j.nn.layers.recurrent;

import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.util.Dropout;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.ops.transforms.Transforms;

/* loaded from: input_file:org/deeplearning4j/nn/layers/recurrent/GravesLSTM.class */
public class GravesLSTM extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.layers.GravesLSTM> {
    public static final String STATE_KEY_PREV_ACTIVATION = "prevAct";
    public static final String STATE_KEY_PREV_MEMCELL = "prevMem";

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/deeplearning4j/nn/layers/recurrent/GravesLSTM$FwdPassReturn.class */
    /**
     * Mutable holder for everything one call to {@code activateHelper} produces.
     *
     * <p>Which fields are populated depends on the mode of the forward pass. When the pass is run
     * for backprop, {@link #paramsZeroOffset} and the per-time-step arrays are filled and
     * {@link #fwdPassOutput} is left {@code null}; otherwise only {@link #fwdPassOutput} is
     * allocated. {@link #lastAct} and {@link #lastMemCell} are overwritten on every time step in
     * both modes, so after the pass they hold the values of the final step.
     */
    public static class FwdPassReturn {
        // ---- populated only when the pass is NOT run for backprop ----

        /** Output of all time steps; allocated as [input.size(0), RW.size(0), number of time steps]. */
        private INDArray fwdPassOutput;

        // ---- populated only when the pass IS run for backprop ----

        /**
         * The 11 weight slices used by the forward pass, in the order: for each of the i, f, o and
         * g slices the "W" slice followed by the "RW" slice, with the single extra "RW" column used
         * by f, o and g placed directly after that slice's "RW" entry (i has no such column).
         */
        private INDArray[] paramsZeroOffset;

        /** Output activations, one entry per time step. */
        private INDArray[] fwdPassOutputAsArrays;

        /** Memory cell state, one entry per time step. */
        private INDArray[] memCellState;

        // Values captured before the transform is applied (copies), one entry per time step.
        private INDArray[] iz;
        private INDArray[] fz;
        private INDArray[] oz;
        private INDArray[] gz;

        // Values after the transform has been applied, one entry per time step.
        private INDArray[] ia;
        private INDArray[] fa;
        private INDArray[] oa;
        private INDArray[] ga;

        // ---- populated in both modes ----

        /** Output activations of the most recently processed time step. */
        private INDArray lastAct;

        /** Memory cell state of the most recently processed time step. */
        private INDArray lastMemCell;

        /** Instances are only created by the enclosing layer. */
        private FwdPassReturn() {
        }
    }

    /**
     * Creates the layer from a network configuration; all work is delegated to the superclass
     * constructor.
     *
     * @param neuralNetConfiguration configuration handed to the superclass
     */
    public GravesLSTM(NeuralNetConfiguration neuralNetConfiguration) {
        super(neuralNetConfiguration);
    }

    /**
     * Creates the layer from a network configuration and an array; both are passed straight to the
     * superclass constructor.
     *
     * @param neuralNetConfiguration configuration handed to the superclass
     * @param iNDArray array handed to the superclass (presumably the layer input; confirm against
     *                 the superclass constructor)
     */
    public GravesLSTM(NeuralNetConfiguration neuralNetConfiguration, INDArray iNDArray) {
        super(neuralNetConfiguration, iNDArray);
    }

    /**
     * Not supported by this layer.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Model
    public Gradient gradient() {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /**
     * Not supported by this layer; both arguments are ignored.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public Gradient calcGradient(Gradient gradient, INDArray iNDArray) {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /**
     * Backpropagates an error signal through all time steps and returns the parameter gradients
     * together with the error signal for the layer below.
     *
     * <p>The forward pass is re-run first (training mode, zero initial state, backprop mode) to
     * obtain the per-time-step values. The time steps are then walked from last to first while
     * gradients are accumulated separately for the i, f, o and g slices. At the end the per-slice
     * gradients are packed into arrays with the same column layout the forward pass reads from the
     * "W", "RW" and "b" parameters.
     *
     * @param iNDArray error signal from the layer above; dimension 0 is read as the mini-batch size
     *                 and, when the rank is 3, dimension 2 as the number of time steps (rank below 3
     *                 is treated as a single time step)
     * @return pair of the gradients (keyed by "W", "RW" and "b") and the outgoing error signal,
     *         allocated as [miniBatch, W.size(0), timeSteps]
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public Pair<Gradient, INDArray> backpropGradient(INDArray iNDArray) {
        final INDArray epsilon = iNDArray;

        // Forward pass in backprop mode: keeps the per-step values and the weight slices.
        final FwdPassReturn fwdPass = activateHelper(true, null, null, true);

        final INDArray inputWeights = getParam("W");
        final int hiddenLayerSize = getParam("RW").size(0);
        final int prevLayerSize = inputWeights.size(0);
        final int miniBatchSize = epsilon.size(0);
        final boolean is2dInput = epsilon.rank() < 3;
        final int timeSeriesLength = is2dInput ? 1 : epsilon.size(2);
        final int lastStep = timeSeriesLength - 1;

        // Slots 0, 2, 5 and 8 hold the "W" slices; they are only needed transposed from here on.
        final INDArray[] weights = fwdPass.paramsZeroOffset;
        for (int slot : new int[]{0, 2, 5, 8}) {
            weights[slot] = Shape.toOffsetZero(weights[slot].transpose());
        }
        final INDArray inputWeightsITransposed = weights[0];
        final INDArray recurrentWeightsI = weights[1];
        final INDArray inputWeightsFTransposed = weights[2];
        final INDArray recurrentWeightsF = weights[3];
        final INDArray cellWeightsF = weights[4];
        final INDArray inputWeightsOTransposed = weights[5];
        final INDArray recurrentWeightsO = weights[6];
        final INDArray cellWeightsO = weights[7];
        final INDArray inputWeightsGTransposed = weights[8];
        final INDArray recurrentWeightsG = weights[9];
        final INDArray cellWeightsG = weights[10];

        // Per-slice accumulators, index order i, f, o, g. The recurrent array carries three extra
        // row vectors (indices 4..6) for the single-column "RW" slices.
        final INDArray[] biasGradsByGate = new INDArray[4];
        final INDArray[] inputWeightGradsByGate = new INDArray[4];
        final INDArray[] recurrentWeightGradsBySlice = new INDArray[7];
        for (int gate = 0; gate < 4; gate++) {
            biasGradsByGate[gate] = Nd4j.zeros(1, hiddenLayerSize);
            inputWeightGradsByGate[gate] = Nd4j.zeros(prevLayerSize, hiddenLayerSize);
            recurrentWeightGradsBySlice[gate] = Nd4j.zeros(hiddenLayerSize, hiddenLayerSize);
        }
        for (int slice = 4; slice < 7; slice++) {
            recurrentWeightGradsBySlice[slice] = Nd4j.zeros(1, hiddenLayerSize);
        }

        final INDArray epsilonNext = Nd4j.zeros(new int[]{miniBatchSize, prevLayerSize, timeSeriesLength});

        // Values carried over from the previously processed (i.e. later in time) step. The i and o
        // deltas start as null because they are only read when a later step exists.
        INDArray nablaCellStateNext = Nd4j.zeros(miniBatchSize, hiddenLayerSize);
        INDArray deltaINext = null;
        INDArray deltaFNext = Nd4j.zeros(miniBatchSize, hiddenLayerSize);
        INDArray deltaONext = null;
        INDArray deltaGNext = Nd4j.zeros(miniBatchSize, hiddenLayerSize);

        for (int t = lastStep; t >= 0; t--) {
            final boolean firstStep = t == 0;
            final boolean finalStep = t == lastStep;
            final String activationFn = this.conf.getLayer().getActivationFunction();

            final INDArray prevMemCellState;
            final INDArray prevHiddenActivations;
            if (firstStep) {
                prevMemCellState = Nd4j.zeros(miniBatchSize, hiddenLayerSize);
                prevHiddenActivations = null;
            } else {
                prevMemCellState = fwdPass.memCellState[t - 1];
                prevHiddenActivations = fwdPass.fwdPassOutputAsArrays[t - 1];
            }
            final INDArray currMemCellState = fwdPass.memCellState[t];

            // Error reaching this step's output: the external signal plus, for every step but the
            // last, the contribution flowing back through the recurrent weights.
            INDArray nablaOut = is2dInput ? epsilon : epsilon.tensorAlongDimension(t, new int[]{1, 0});
            if (!finalStep) {
                // Copy first so the caller's epsilon is not modified by the in-place adds.
                nablaOut = nablaOut.dup();
                nablaOut.addi(deltaINext.mmul(recurrentWeightsI.transpose()))
                        .addi(deltaFNext.mmul(recurrentWeightsF.transpose()))
                        .addi(deltaONext.mmul(recurrentWeightsO.transpose()))
                        .addi(deltaGNext.mmul(recurrentWeightsG.transpose()));
            }

            // o delta.
            final INDArray activationOfCell = Nd4j.getExecutioner().execAndReturn(
                    Nd4j.getOpFactory().createTransform(activationFn, currMemCellState.dup()));
            final INDArray sigmoidPrimeOfOz = Nd4j.getExecutioner().execAndReturn(
                    Nd4j.getOpFactory().createTransform("sigmoid", fwdPass.oz[t]).derivative());
            final INDArray deltaO = nablaOut.mul(activationOfCell).muli(sigmoidPrimeOfOz);

            // Error with respect to the memory cell state.
            final INDArray activationPrimeOfCell = Nd4j.getExecutioner().execAndReturn(
                    Nd4j.getOpFactory().createTransform(activationFn, currMemCellState.dup()).derivative());
            final INDArray nextForgetActivations = finalStep
                    ? Nd4j.zeros(miniBatchSize, hiddenLayerSize)
                    : fwdPass.fa[t + 1];
            final INDArray nablaCellState = nablaOut.mul(fwdPass.oa[t])
                    .muli(activationPrimeOfCell)
                    .addi(nextForgetActivations.mul(nablaCellStateNext))
                    .addi(deltaFNext.mulRowVector(cellWeightsF.transpose()))
                    .addi(deltaO.mulRowVector(cellWeightsO.transpose()))
                    .addi(deltaGNext.mulRowVector(cellWeightsG.transpose()));
            nablaCellStateNext = nablaCellState;

            // f, g and i deltas.
            final INDArray sigmoidPrimeOfFz = Nd4j.getExecutioner().execAndReturn(
                    Nd4j.getOpFactory().createTransform("sigmoid", fwdPass.fz[t]).derivative());
            final INDArray deltaF = nablaCellState.mul(prevMemCellState).muli(sigmoidPrimeOfFz);

            final INDArray sigmoidPrimeOfGz = Nd4j.getExecutioner().execAndReturn(
                    Nd4j.getOpFactory().createTransform("sigmoid", fwdPass.gz[t]).derivative());
            final INDArray deltaG = nablaCellState.mul(fwdPass.ia[t]).muli(sigmoidPrimeOfGz);

            final INDArray activationPrimeOfIz = Nd4j.getExecutioner().execAndReturn(
                    Nd4j.getOpFactory().createTransform(activationFn, fwdPass.iz[t]).derivative());
            final INDArray deltaI = nablaCellState.mul(fwdPass.ga[t]).muli(activationPrimeOfIz);

            // "W" gradients: (layer input at t)^T * delta.
            final INDArray layerInputTransposed = Shape.toOffsetZero(is2dInput
                    ? this.input.transpose()
                    : this.input.tensorAlongDimension(t, new int[]{1, 0}).transpose());
            inputWeightGradsByGate[0].addi(layerInputTransposed.mmul(deltaI));
            inputWeightGradsByGate[1].addi(layerInputTransposed.mmul(deltaF));
            inputWeightGradsByGate[2].addi(layerInputTransposed.mmul(deltaO));
            inputWeightGradsByGate[3].addi(layerInputTransposed.mmul(deltaG));

            // "RW" gradients that depend on the previous step; skipped at t == 0 where there is
            // no previous output.
            if (!firstStep) {
                final INDArray prevHiddenTransposed = Shape.toOffsetZero(prevHiddenActivations.transpose());
                recurrentWeightGradsBySlice[0].addi(prevHiddenTransposed.mmul(deltaI));
                recurrentWeightGradsBySlice[1].addi(prevHiddenTransposed.mmul(deltaF));
                recurrentWeightGradsBySlice[2].addi(prevHiddenTransposed.mmul(deltaO));
                recurrentWeightGradsBySlice[3].addi(prevHiddenTransposed.mmul(deltaG));
                recurrentWeightGradsBySlice[4].addi(deltaF.mul(prevMemCellState).sum(new int[]{0}));
                recurrentWeightGradsBySlice[6].addi(deltaG.mul(prevMemCellState).sum(new int[]{0}));
            }
            // The o cell column uses the current cell state, so it is accumulated on every step.
            recurrentWeightGradsBySlice[5].addi(deltaO.mul(currMemCellState).sum(new int[]{0}));

            // "b" gradients: deltas summed along dimension 0.
            biasGradsByGate[0].addi(deltaI.sum(new int[]{0}));
            biasGradsByGate[1].addi(deltaF.sum(new int[]{0}));
            biasGradsByGate[2].addi(deltaO.sum(new int[]{0}));
            biasGradsByGate[3].addi(deltaG.sum(new int[]{0}));

            // Error signal for the layer below at this time step.
            final INDArray epsilonNextSlice = deltaI.mmul(inputWeightsITransposed)
                    .addi(deltaF.mmul(inputWeightsFTransposed))
                    .addi(deltaO.mmul(inputWeightsOTransposed))
                    .addi(deltaG.mmul(inputWeightsGTransposed));
            epsilonNext.tensorAlongDimension(t, new int[]{1, 0}).assign(epsilonNextSlice);

            deltaINext = deltaI;
            deltaFNext = deltaF;
            deltaONext = deltaO;
            deltaGNext = deltaG;
        }

        // Pack the per-slice gradients into arrays laid out like the parameters themselves.
        final INDArray packedInputWeightGrad = Nd4j.zeros(prevLayerSize, 4 * hiddenLayerSize);
        final INDArray packedRecurrentWeightGrad = Nd4j.zeros(hiddenLayerSize, (4 * hiddenLayerSize) + 3);
        final INDArray packedBiasGrad = Nd4j.hstack(biasGradsByGate);
        for (int gate = 0; gate < 4; gate++) {
            final int from = gate * hiddenLayerSize;
            final int to = (gate + 1) * hiddenLayerSize;
            packedInputWeightGrad.put(
                    new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(from, to)},
                    inputWeightGradsByGate[gate]);
        }
        for (int gate = 0; gate < 4; gate++) {
            final int from = gate * hiddenLayerSize;
            final int to = (gate + 1) * hiddenLayerSize;
            packedRecurrentWeightGrad.put(
                    new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(from, to)},
                    recurrentWeightGradsBySlice[gate]);
        }
        for (int extra = 0; extra < 3; extra++) {
            packedRecurrentWeightGrad.put(
                    new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.point((4 * hiddenLayerSize) + extra)},
                    recurrentWeightGradsBySlice[4 + extra].transpose());
        }

        final DefaultGradient result = new DefaultGradient();
        result.gradientForVariable().put("W", packedInputWeightGrad);
        result.gradientForVariable().put("RW", packedRecurrentWeightGrad);
        result.gradientForVariable().put("b", packedBiasGrad);
        return new Pair<>(result, epsilonNext);
    }

    /**
     * Returns the layer activations for the given input with the training flag set to
     * {@code true}; this layer makes no distinction between pre-output and activation.
     *
     * @param iNDArray layer input
     * @return result of {@link #activate(INDArray, boolean)}
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public INDArray preOutput(INDArray iNDArray) {
        final boolean training = true;
        return activate(iNDArray, training);
    }

    /**
     * Returns the layer activations for the given input; this layer makes no distinction between
     * pre-output and activation.
     *
     * @param iNDArray layer input
     * @param z        training flag, forwarded unchanged
     * @return result of {@link #activate(INDArray, boolean)}
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public INDArray preOutput(INDArray iNDArray, boolean z) {
        final INDArray activations = activate(iNDArray, z);
        return activations;
    }

    /**
     * Stores the given input on the layer and runs the forward pass from a zero initial state.
     *
     * @param iNDArray layer input
     * @param z        training flag; passed to both {@code setInput} and the forward pass
     * @return output activations of all time steps
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public INDArray activate(INDArray iNDArray, boolean z) {
        setInput(iNDArray, z);
        final FwdPassReturn fwdPass = activateHelper(z, null, null, false);
        return fwdPass.fwdPassOutput;
    }

    /**
     * Stores the given input on the layer and runs the forward pass in training mode from a zero
     * initial state.
     *
     * @param iNDArray layer input
     * @return output activations of all time steps
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public INDArray activate(INDArray iNDArray) {
        setInput(iNDArray);
        final FwdPassReturn fwdPass = activateHelper(true, null, null, false);
        return fwdPass.fwdPassOutput;
    }

    /**
     * Runs the forward pass on the input already stored on the layer, from a zero initial state.
     *
     * @param z training flag
     * @return output activations of all time steps
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public INDArray activate(boolean z) {
        final FwdPassReturn fwdPass = activateHelper(z, null, null, false);
        return fwdPass.fwdPassOutput;
    }

    /**
     * Runs the forward pass on the input already stored on the layer, with the training flag off
     * and a zero initial state.
     *
     * @return output activations of all time steps
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public INDArray activate() {
        final FwdPassReturn fwdPass = activateHelper(false, null, null, false);
        return fwdPass.fwdPassOutput;
    }

    /**
     * Runs the forward pass over every time step of {@code this.input}.
     *
     * <p>The "W", "RW" and "b" parameters are read as four column slices of width
     * {@code RW.size(0)} each, in the order i, f, o, g. "RW" additionally carries three single
     * columns after those slices, used (in that order) by f, o and g together with the memory
     * cell state.
     *
     * @param training              when true, and drop-connect is enabled with a positive dropout
     *                              value, "W" is replaced by its drop-connect version
     * @param prevOutputActivations output activations to start from; zeros are used when null
     * @param prevMemCellState      memory cell state to start from; zeros are used when null
     * @param forBackprop           when true, the per-time-step values and the weight slices are
     *                              stored on the result instead of the combined output array
     * @return holder with either the combined output (non-backprop mode) or the per-step values
     *         (backprop mode), plus the output and cell state of the last processed step
     */
    private FwdPassReturn activateHelper(boolean training, INDArray prevOutputActivations, INDArray prevMemCellState, boolean forBackprop) {
        final INDArray recurrentWeights = getParam("RW");
        INDArray inputWeights = getParam("W");
        final INDArray biases = getParam("b");

        final boolean is2dInput = this.input.rank() < 3;
        final int timeSeriesLength = is2dInput ? 1 : this.input.size(2);
        final int hiddenLayerSize = recurrentWeights.size(0);
        final int miniBatchSize = this.input.size(0);

        if (this.conf.isUseDropConnect() && training && this.conf.getLayer().getDropOut() > 0.0d) {
            inputWeights = Dropout.applyDropConnect(this, "W");
        }

        final int n = hiddenLayerSize;

        // i slice: columns [0, n).
        INDArray inputWeightsI = inputWeights.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(0, n)});
        INDArray recurrentWeightsI = recurrentWeights.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(0, n)});
        INDArray biasI = biases.get(new INDArrayIndex[]{NDArrayIndex.point(0), NDArrayIndex.interval(0, n)});

        // f slice: columns [n, 2n) plus "RW" column 4n.
        INDArray inputWeightsF = inputWeights.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(n, 2 * n)});
        INDArray recurrentWeightsF = recurrentWeights.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(n, 2 * n)});
        INDArray cellWeightsF = recurrentWeights.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(4 * n, (4 * n) + 1)});
        INDArray biasF = biases.get(new INDArrayIndex[]{NDArrayIndex.point(0), NDArrayIndex.interval(n, 2 * n)});

        // o slice: columns [2n, 3n) plus "RW" column 4n + 1.
        INDArray inputWeightsO = inputWeights.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(2 * n, 3 * n)});
        INDArray recurrentWeightsO = recurrentWeights.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(2 * n, 3 * n)});
        INDArray cellWeightsO = recurrentWeights.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval((4 * n) + 1, (4 * n) + 2)});
        INDArray biasO = biases.get(new INDArrayIndex[]{NDArrayIndex.point(0), NDArrayIndex.interval(2 * n, 3 * n)});

        // g slice: columns [3n, 4n) plus "RW" column 4n + 2.
        INDArray inputWeightsG = inputWeights.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(3 * n, 4 * n)});
        INDArray recurrentWeightsG = recurrentWeights.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(3 * n, 4 * n)});
        INDArray cellWeightsG = recurrentWeights.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval((4 * n) + 2, (4 * n) + 3)});
        INDArray biasG = biases.get(new INDArrayIndex[]{NDArrayIndex.point(0), NDArrayIndex.interval(3 * n, 4 * n)});

        // The slices are reused on every time step (and by backprop), so convert them once up
        // front. NOTE(review): Shape.toOffsetZero presumably copies a view into a zero-offset
        // array; confirm against the ND4J Shape documentation.
        if (timeSeriesLength > 1 || forBackprop) {
            inputWeightsI = Shape.toOffsetZero(inputWeightsI);
            recurrentWeightsI = Shape.toOffsetZero(recurrentWeightsI);
            biasI = Shape.toOffsetZero(biasI);

            inputWeightsF = Shape.toOffsetZero(inputWeightsF);
            recurrentWeightsF = Shape.toOffsetZero(recurrentWeightsF);
            cellWeightsF = Shape.toOffsetZero(cellWeightsF);
            biasF = Shape.toOffsetZero(biasF);

            inputWeightsO = Shape.toOffsetZero(inputWeightsO);
            recurrentWeightsO = Shape.toOffsetZero(recurrentWeightsO);
            cellWeightsO = Shape.toOffsetZero(cellWeightsO);
            biasO = Shape.toOffsetZero(biasO);

            inputWeightsG = Shape.toOffsetZero(inputWeightsG);
            recurrentWeightsG = Shape.toOffsetZero(recurrentWeightsG);
            cellWeightsG = Shape.toOffsetZero(cellWeightsG);
            biasG = Shape.toOffsetZero(biasG);
        }

        final FwdPassReturn toReturn = new FwdPassReturn();
        INDArray outputActivations = null;
        if (forBackprop) {
            toReturn.paramsZeroOffset = new INDArray[]{
                    inputWeightsI, recurrentWeightsI,
                    inputWeightsF, recurrentWeightsF, cellWeightsF,
                    inputWeightsO, recurrentWeightsO, cellWeightsO,
                    inputWeightsG, recurrentWeightsG, cellWeightsG};
            toReturn.fwdPassOutputAsArrays = new INDArray[timeSeriesLength];
            toReturn.memCellState = new INDArray[timeSeriesLength];
            toReturn.iz = new INDArray[timeSeriesLength];
            toReturn.ia = new INDArray[timeSeriesLength];
            toReturn.fz = new INDArray[timeSeriesLength];
            toReturn.fa = new INDArray[timeSeriesLength];
            toReturn.oz = new INDArray[timeSeriesLength];
            toReturn.oa = new INDArray[timeSeriesLength];
            toReturn.gz = new INDArray[timeSeriesLength];
            toReturn.ga = new INDArray[timeSeriesLength];
        } else {
            outputActivations = Nd4j.zeros(new int[]{miniBatchSize, hiddenLayerSize, timeSeriesLength});
            toReturn.fwdPassOutput = outputActivations;
        }

        // No stored state supplied: start from zeros.
        if (prevOutputActivations == null) {
            prevOutputActivations = Nd4j.zeros(new int[]{miniBatchSize, hiddenLayerSize});
        }
        if (prevMemCellState == null) {
            prevMemCellState = Nd4j.zeros(new int[]{miniBatchSize, hiddenLayerSize});
        }

        for (int t = 0; t < timeSeriesLength; t++) {
            final String activationFn = this.conf.getLayer().getActivationFunction();
            final INDArray inputSlice = Shape.toOffsetZero(
                    is2dInput ? this.input : this.input.tensorAlongDimension(t, new int[]{1, 0}));

            // i: uses the layer's configured activation function. The transform is executed on
            // the array itself, which is why the z copy is taken before the call.
            final INDArray inputActivations = inputSlice.mmul(inputWeightsI)
                    .addi(prevOutputActivations.mmul(recurrentWeightsI))
                    .addiRowVector(biasI);
            if (forBackprop) {
                toReturn.iz[t] = inputActivations.dup();
            }
            Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform(activationFn, inputActivations));
            if (forBackprop) {
                toReturn.ia[t] = inputActivations;
            }

            // f: sigmoid, with a contribution from the previous memory cell state.
            final INDArray forgetActivations = inputSlice.mmul(inputWeightsF)
                    .addi(prevOutputActivations.mmul(recurrentWeightsF))
                    .addi(prevMemCellState.mulRowVector(cellWeightsF.transpose()))
                    .addiRowVector(biasF);
            if (forBackprop) {
                toReturn.fz[t] = forgetActivations.dup();
            }
            Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform("sigmoid", forgetActivations));
            if (forBackprop) {
                toReturn.fa[t] = forgetActivations;
            }

            // g: sigmoid, with a contribution from the previous memory cell state.
            final INDArray gActivations = inputSlice.mmul(inputWeightsG)
                    .addi(prevOutputActivations.mmul(recurrentWeightsG))
                    .addi(prevMemCellState.mulRowVector(cellWeightsG.transpose()))
                    .addiRowVector(biasG);
            if (forBackprop) {
                toReturn.gz[t] = gActivations.dup();
            }
            Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform("sigmoid", gActivations));
            if (forBackprop) {
                toReturn.ga[t] = gActivations;
            }

            // New memory cell state: f * previous state + g * i.
            final INDArray currentMemCellState = forgetActivations.mul(prevMemCellState)
                    .addi(gActivations.mul(inputActivations));

            // o: sigmoid, with a contribution from the NEW memory cell state.
            final INDArray outputGateActivations = inputSlice.mmul(inputWeightsO)
                    .addi(prevOutputActivations.mmul(recurrentWeightsO))
                    .addi(currentMemCellState.mulRowVector(cellWeightsO.transpose()))
                    .addiRowVector(biasO);
            if (forBackprop) {
                toReturn.oz[t] = outputGateActivations.dup();
            }
            Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform("sigmoid", outputGateActivations));
            if (forBackprop) {
                toReturn.oa[t] = outputGateActivations;
            }

            // Step output: activation(cell state) * o. The cell state is copied first so the
            // stored state is not overwritten by the transform.
            final INDArray currentHiddenActivations = Nd4j.getExecutioner()
                    .execAndReturn(Nd4j.getOpFactory().createTransform(activationFn, currentMemCellState.dup()))
                    .muli(outputGateActivations);

            if (forBackprop) {
                toReturn.fwdPassOutputAsArrays[t] = currentHiddenActivations;
                toReturn.memCellState[t] = currentMemCellState;
            } else {
                outputActivations.tensorAlongDimension(t, new int[]{1, 0}).assign(currentHiddenActivations);
            }

            prevOutputActivations = currentHiddenActivations;
            prevMemCellState = currentMemCellState;
            toReturn.lastAct = currentHiddenActivations;
            toReturn.lastMemCell = currentMemCellState;
        }
        return toReturn;
    }

    /**
     * Returns the same value as {@link #activate()}; no separate mean is computed.
     *
     * @return output activations for the input currently stored on the layer
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public INDArray activationMean() {
        final INDArray activations = activate();
        return activations;
    }

    /**
     * Identifies this layer as a recurrent layer.
     *
     * @return always {@code Layer.Type.RECURRENT}
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public Layer.Type type() {
        return Layer.Type.RECURRENT;
    }

    /**
     * Not supported by this layer.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public Layer transpose() {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /**
     * Computes the L2 penalty over the "RW" and "W" parameters; "b" is not included.
     *
     * @return {@code 0.5 * l2 * (sum(RW^2) + sum(W^2))}, or 0 when regularization is disabled or
     *         the configured L2 coefficient is not positive
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public double calcL2() {
        if (!this.conf.isUseRegularization() || this.conf.getL2() <= 0.0d) {
            return 0.0d;
        }
        // Integer.MAX_VALUE as the dimension presumably requests a sum over all elements;
        // confirm against the ND4J INDArray.sum documentation.
        final double recurrentSumOfSquares =
                Transforms.pow(getParam("RW"), 2).sum(new int[]{Integer.MAX_VALUE}).getDouble(0);
        final double inputSumOfSquares =
                Transforms.pow(getParam("W"), 2).sum(new int[]{Integer.MAX_VALUE}).getDouble(0);
        return 0.5d * this.conf.getL2() * (recurrentSumOfSquares + inputSumOfSquares);
    }

    /**
     * Computes the L1 penalty over the "RW" and "W" parameters; "b" is not included.
     *
     * @return {@code l1 * (sum(|RW|) + sum(|W|))}, or 0 when regularization is disabled or the
     *         configured L1 coefficient is not positive
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public double calcL1() {
        if (!this.conf.isUseRegularization() || this.conf.getL1() <= 0.0d) {
            return 0.0d;
        }
        // Integer.MAX_VALUE as the dimension presumably requests a sum over all elements;
        // confirm against the ND4J INDArray.sum documentation.
        final double recurrentAbsSum =
                Transforms.abs(getParam("RW")).sum(new int[]{Integer.MAX_VALUE}).getDouble(0);
        final double inputAbsSum =
                Transforms.abs(getParam("W")).sum(new int[]{Integer.MAX_VALUE}).getDouble(0);
        return this.conf.getL1() * (recurrentAbsSum + inputAbsSum);
    }

    /**
     * Runs the forward pass on the given input, starting from the state stored in
     * {@code stateMap}, and then stores the output and memory cell state of the last processed
     * time step back into {@code stateMap} for the next call. The training flag is off.
     *
     * @param iNDArray layer input for this call
     * @return output activations of all time steps in the input
     */
    @Override // org.deeplearning4j.nn.layers.recurrent.BaseRecurrentLayer
    public INDArray rnnTimeStep(INDArray iNDArray) {
        setInput(iNDArray);

        // Entries missing from the map come back as null, which the forward pass replaces by zeros.
        final INDArray storedActivations = this.stateMap.get(STATE_KEY_PREV_ACTIVATION);
        final INDArray storedMemCell = this.stateMap.get(STATE_KEY_PREV_MEMCELL);
        final FwdPassReturn fwdPass = activateHelper(false, storedActivations, storedMemCell, false);

        this.stateMap.put(STATE_KEY_PREV_ACTIVATION, fwdPass.lastAct);
        this.stateMap.put(STATE_KEY_PREV_MEMCELL, fwdPass.lastMemCell);
        return fwdPass.fwdPassOutput;
    }
}
