/*
 * Decompiled with CFR 0.152.
 */
package org.dromara.transFormer.nerve;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import org.dromara.i.ActiveFunction;
import org.dromara.i.OutBack;
import org.dromara.matrixTools.Matrix;
import org.dromara.matrixTools.MatrixOperation;
import org.dromara.transFormer.LineBlock;
import org.dromara.transFormer.seflAttention.LayNorm;

public abstract class Nerve {
    private final List<Nerve> son = new ArrayList<Nerve>();
    private final List<Nerve> father = new ArrayList<Nerve>();
    protected LayNorm beforeLayNorm;
    protected LayNorm afterLayNorm;
    protected Matrix powerMatrix;
    private final int id;
    private final int hiddenNerveNub;
    private final int sensoryNerveNub;
    private final int outNerveNub;
    protected Map<Long, Matrix> reMatrixFeatures = new HashMap<Long, Matrix>();
    protected String name;
    protected Matrix featureMatrix;
    protected double E;
    protected double studyPoint;
    protected LineBlock lineBlock;
    protected Matrix sigmaW;
    private int backNub = 0;
    protected ActiveFunction activeFunction;
    protected Matrix outMatrix;
    protected int myUpNumber;
    protected int depth;
    private final int regularModel;
    private final double regular;
    private final MatrixOperation matrixOperation;

    public int getDepth() {
        return this.depth;
    }

    public void setBeforeLayNorm(LayNorm beforeLayNorm) {
        this.beforeLayNorm = beforeLayNorm;
    }

    public void setAfterLayNorm(LayNorm afterLayNorm) {
        this.afterLayNorm = afterLayNorm;
    }

    protected Nerve(int id, String name, double studyPoint, ActiveFunction activeFunction, int sensoryNerveNub, int hiddenNerveNub, int outNerveNub, LineBlock lineBlock, int regularModel, double regular, int coreNumber) throws Exception {
        this.id = id;
        this.matrixOperation = new MatrixOperation(coreNumber);
        this.regular = regular;
        this.regularModel = regularModel;
        this.lineBlock = lineBlock;
        this.hiddenNerveNub = hiddenNerveNub;
        this.sensoryNerveNub = sensoryNerveNub;
        this.outNerveNub = outNerveNub;
        this.name = name;
        this.studyPoint = studyPoint;
        this.activeFunction = activeFunction;
        this.initPower();
    }

    public double[][] getModel() {
        return this.powerMatrix.getMatrix();
    }

    public void insertModel(double[][] modelPower) throws Exception {
        for (int i = 0; i < this.powerMatrix.getX(); ++i) {
            for (int j = 0; j < this.powerMatrix.getY(); ++j) {
                this.powerMatrix.setNub(i, j, modelPower[i][j]);
            }
        }
    }

    protected void sendMessage(long eventId, Matrix parameter, boolean isStudy, Matrix allFeature, OutBack outBack, List<Integer> E, Matrix encoderFeature, boolean outAllPro) throws Exception {
        if (!this.son.isEmpty()) {
            for (Nerve nerve : this.son) {
                nerve.input(eventId, parameter, isStudy, allFeature, outBack, E, encoderFeature, outAllPro);
            }
        }
    }

    private void backSendMessage(long eventId, Matrix errorMatrix, Matrix allError) throws Exception {
        if (!this.father.isEmpty()) {
            if (errorMatrix.getY() - 1 != this.father.size()) {
                throw new Exception("\u56de\u4f20\u53c2\u6570\u6570\u91cf\u4e0d\u4e00\u81f4!");
            }
            for (int i = 0; i < this.father.size(); ++i) {
                this.father.get(i).backGetMessage(errorMatrix.getColumn(i), eventId, allError);
            }
        } else if (this.lineBlock != null) {
            this.lineBlock.backError(eventId, errorMatrix);
        } else {
            this.afterLayNorm.backErrorFromFNN(errorMatrix, eventId, allError);
        }
    }

    protected void input(long eventId, Matrix parameter, boolean isStudy, Matrix allFeature, OutBack outBack, List<Integer> E, Matrix encoderFeature, boolean outAllPro) throws Exception {
    }

    protected void toOut(long eventId, Matrix parameter, boolean isStudy, OutBack outBack, List<Integer> E, boolean outAllPro) throws Exception {
    }

    protected void sendOutMessage(long eventId, Matrix parameter, boolean isStudy, OutBack outBack, List<Integer> E, boolean outAllPro) throws Exception {
        if (!this.son.isEmpty()) {
            for (Nerve nerve : this.son) {
                nerve.toOut(eventId, parameter, isStudy, outBack, E, outAllPro);
            }
        }
    }

    private void backGetMessage(Matrix parameter, long eventId, Matrix allError) throws Exception {
        ++this.backNub;
        this.sigmaW = this.sigmaW == null ? parameter : this.matrixOperation.add(this.sigmaW, parameter);
        if (this.backNub == this.outNerveNub) {
            this.backNub = 0;
            if (this.activeFunction != null) {
                for (int i = 0; i < this.sigmaW.getX(); ++i) {
                    double out = this.outMatrix.getNumber(i, 0);
                    double value = this.activeFunction.functionG(out) * this.sigmaW.getNumber(i, 0);
                    this.sigmaW.setNub(i, 0, value);
                }
            }
            this.updatePower(eventId, this.sigmaW, allError);
        }
    }

    protected void updatePower(long eventId, Matrix errorMatrix, Matrix allError) throws Exception {
        Matrix myError = this.matrixOperation.mathMulBySelf(errorMatrix, this.studyPoint);
        Matrix error = this.updateW(myError, errorMatrix);
        this.sigmaW = null;
        this.backSendMessage(eventId, error, allError);
    }

    private Matrix getRegularizationMatrix() throws Exception {
        int size = this.powerMatrix.getX();
        double sigma = 0.0;
        for (int i = 0; i < size; ++i) {
            double value = this.powerMatrix.getNumber(i, 0);
            if (this.regularModel == 1) {
                sigma += Math.abs(value);
                continue;
            }
            sigma += Math.pow(value, 2.0);
        }
        double param = sigma * this.regular * this.studyPoint;
        Matrix rzMatrix = new Matrix(this.powerMatrix.getX(), this.powerMatrix.getY());
        for (int i = 0; i < size; ++i) {
            double value = this.powerMatrix.getNumber(i, 0);
            double re = 0.0;
            if (this.regularModel == 2) {
                re = param * -value;
            } else if (this.regularModel == 1) {
                if (value > 0.0) {
                    re = -param;
                } else if (value < 0.0) {
                    re = param;
                }
            }
            rzMatrix.setNub(i, 0, re);
        }
        return rzMatrix;
    }

    private Matrix updateW(Matrix errorMatrix, Matrix error) throws Exception {
        Matrix rzMatrix = null;
        if (this.regularModel != 0) {
            rzMatrix = this.getRegularizationMatrix();
        }
        Matrix subFeature = this.matrixOperation.matrixMulPd(error, this.featureMatrix, this.powerMatrix, true);
        Matrix subPower = this.matrixOperation.matrixMulPd(errorMatrix, this.featureMatrix, this.powerMatrix, false);
        if (this.regularModel != 0) {
            this.powerMatrix = this.matrixOperation.add(this.powerMatrix, rzMatrix);
        }
        this.powerMatrix = this.matrixOperation.add(this.powerMatrix, subPower);
        return subFeature;
    }

    protected boolean insertMatrixParameter(long eventID, Matrix matrix) throws Exception {
        Matrix feature;
        boolean allReady = false;
        if (this.reMatrixFeatures.containsKey(eventID)) {
            Matrix myFeature = this.reMatrixFeatures.get(eventID);
            feature = this.matrixOperation.pushVector(myFeature, matrix, false);
        } else {
            feature = matrix;
        }
        this.reMatrixFeatures.put(eventID, feature);
        if (feature.getY() == this.myUpNumber) {
            allReady = true;
        } else if (feature.getY() > this.myUpNumber) {
            throw new Exception("\u63a5\u6536\u77e9\u9635\u53c2\u6570\u6570\u91cf\u5f02\u5e38");
        }
        return allReady;
    }

    protected Matrix opMatrix(Matrix feature, boolean isStudy) throws Exception {
        Matrix th = new Matrix(feature.getX(), 1);
        for (int i = 0; i < th.getX(); ++i) {
            th.setNub(i, 0, 1.0);
        }
        Matrix matrix = this.matrixOperation.pushVector(feature, th, false);
        Matrix sigma = this.matrixOperation.mulMatrix(matrix, this.powerMatrix);
        if (this.activeFunction != null) {
            for (int i = 0; i < sigma.getX(); ++i) {
                double value = this.activeFunction.function(sigma.getNumber(i, 0));
                sigma.setNub(i, 0, value);
            }
        }
        if (isStudy) {
            this.featureMatrix = matrix;
            this.outMatrix = sigma;
        }
        return sigma;
    }

    private void initPower() throws Exception {
        Random random = new Random();
        this.myUpNumber = this.name.equals("HiddenNerve") ? this.sensoryNerveNub : (this.name.equals("OutNerve") ? this.hiddenNerveNub : this.outNerveNub);
        if (this.myUpNumber > 0) {
            this.powerMatrix = new Matrix(this.myUpNumber + 1, 1);
            double sh = Math.sqrt(this.myUpNumber);
            for (int i = 0; i < this.myUpNumber; ++i) {
                double nub = random.nextDouble() / sh;
                this.powerMatrix.setNub(i, 0, nub);
            }
            this.powerMatrix.setNub(this.myUpNumber, 0, random.nextDouble() / sh);
        }
    }

    public int getId() {
        return this.id;
    }

    public void connect(List<Nerve> nerveList) {
        this.son.addAll(nerveList);
    }

    public void connectFather(List<Nerve> nerveList) {
        this.father.addAll(nerveList);
    }
}

