/*
 * Decompiled with CFR 0.152.
 */
package org.wlld.rnnNerveEntity;

import java.util.Map;
import org.wlld.i.ActiveFunction;
import org.wlld.i.OutBack;
import org.wlld.matrixTools.Matrix;
import org.wlld.rnnNerveEntity.Nerve;

public class OutNerve
extends Nerve {
    private Map<Integer, Matrix> matrixMapE;
    private boolean isShowLog;
    private boolean isSoftMax;

    public OutNerve(int id, int upNub, int downNub, double studyPoint, boolean init, ActiveFunction activeFunction, boolean isDynamic, boolean isShowLog, int rzType, double lParam, boolean isSoftMax, int step, int kernLen) throws Exception {
        super(id, upNub, "OutNerve", downNub, studyPoint, init, activeFunction, isDynamic, rzType, lParam, step, kernLen, 0);
        this.isShowLog = isShowLog;
        this.isSoftMax = isSoftMax;
    }

    void getGBySoftMax(double g, long eventId) throws Exception {
        this.gradient = g;
        this.updatePower(eventId);
    }

    public void setMatrixMap(Map<Integer, Matrix> matrixMap) {
        this.matrixMapE = matrixMap;
    }

    @Override
    public void input(long eventId, double parameter, boolean isStudy, Map<Integer, Double> E, OutBack outBack, boolean isEmbedding, Matrix rnnMatrix) throws Exception {
        boolean allReady = this.insertParameter(eventId, parameter);
        if (allReady) {
            double sigma = this.calculation(eventId);
            if (this.isSoftMax) {
                if (!isStudy) {
                    this.destoryParameter(eventId);
                }
                this.sendMessage(eventId, sigma, isStudy, E, outBack, false, rnnMatrix);
            } else {
                double out = this.activeFunction.function(sigma);
                if (isStudy) {
                    this.outNub = out;
                    this.E = E.containsKey(this.getId()) ? E.get(this.getId()) : 0.0;
                    if (this.isShowLog) {
                        System.out.println("E==" + this.E + ",out==" + out + ",nerveId==" + this.getId());
                    }
                    this.gradient = this.outGradient();
                    this.updatePower(eventId);
                } else {
                    this.destoryParameter(eventId);
                    if (outBack != null) {
                        outBack.getBack(out, this.getId(), eventId);
                    } else {
                        throw new Exception("not find outBack");
                    }
                }
            }
        }
    }

    /*
     * Enabled force condition propagation
     * Lifted jumps to return sites
     */
    @Override
    protected void inputMatrix(long eventId, Matrix matrix, boolean isKernelStudy, int E, OutBack outBack) throws Exception {
        Matrix myMatrix = this.conv(matrix);
        if (isKernelStudy) {
            Matrix matrix1 = this.matrixMapE.get(E);
            if (this.isShowLog) {
                System.out.println("E========" + E);
                System.out.println(myMatrix.getString());
            }
            if (matrix1.getX() != myMatrix.getX() || matrix1.getY() != myMatrix.getY()) throw new Exception("Wrong size setting of image in templateConfig");
            Matrix g = this.getGradient(myMatrix, matrix1);
            this.backMatrix(g);
            return;
        } else {
            if (outBack == null) throw new Exception("not find outBack");
            outBack.getBackMatrix(myMatrix, this.getId(), eventId);
        }
    }

    private Matrix getGradient(Matrix matrix, Matrix E) throws Exception {
        Matrix matrix1 = new Matrix(matrix.getX(), matrix.getY());
        for (int i = 0; i < E.getX(); ++i) {
            for (int j = 0; j < E.getY(); ++j) {
                double nub = E.getNumber(i, j) - matrix.getNumber(i, j);
                matrix1.setNub(i, j, nub);
            }
        }
        return matrix1;
    }

    private double outGradient() {
        return this.activeFunction.functionG(this.outNub) * (this.E - this.outNub);
    }
}

