/*
 * Decompiled with CFR 0.152.
 */
package org.dromara.transFormer.seflAttention;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import org.dromara.i.OutBack;
import org.dromara.matrixTools.Matrix;
import org.dromara.matrixTools.MatrixOperation;
import org.dromara.transFormer.CodecBlock;
import org.dromara.transFormer.FirstDecoderBlock;
import org.dromara.transFormer.model.LayNormModel;
import org.dromara.transFormer.nerve.HiddenNerve;
import org.dromara.transFormer.seflAttention.MultiSelfAttention;

/**
 * Add-and-normalize layer.
 *
 * <p>Forward pass ({@link #addNorm}): adds two matrices, normalizes every row
 * to zero mean / unit deviation, multiplies each normalized row by the square
 * weight matrix {@code power} and adds the bias row {@code bTa}. The result is
 * forwarded to the next component chosen by {@code type}.
 *
 * <p>Backward pass ({@link #backErrorFromLine}): updates {@code bTa} and
 * {@code power} from the incoming error and propagates an error matrix to the
 * previous component chosen by {@code type}.
 *
 * <p>NOTE(review): the class keeps per-pass mutable state ({@code myNormData},
 * {@code myFinalError}, {@code number}, {@code reMatrixMap}) without any
 * synchronization visible here — confirm callers use one instance from a
 * single thread at a time.
 */
public class LayNorm {
    /** Receiver of the back-propagated error when {@code type != 2}. */
    private MultiSelfAttention multiSelfAttention;
    private final CodecBlock myEncoderBlock;
    /** Width of a feature row; also the side length of {@code power}. */
    private final int featureDimension;
    /** Forward target when {@code type == 1}; backward target when {@code type == 2}. */
    private List<HiddenNerve> hiddenNerves;
    /** Selects the forward/backward routing; only the values 1 and 2 are tested for in this class. */
    private final int type;
    /** Per-event buffer used by {@link #addNormFromNerve} while columns are still arriving. */
    private final Map<Long, Matrix> reMatrixMap = new HashMap<Long, Matrix>();
    private final FirstDecoderBlock firstDecoderBlock;
    /** Bias row, 1 x featureDimension. */
    private Matrix bTa;
    /** Weight matrix, featureDimension x featureDimension. */
    private Matrix power;
    /** Normalized rows of the last forward pass that ran with {@code isStudy == true}. */
    private Matrix myNormData;
    /** Learning rate. */
    private final double study;
    /** Running sum of the errors received since the last backward pass. */
    private Matrix myFinalError;
    /** Number of {@link #backErrorFromFNN} calls received in the current backward round. */
    private int number;
    private final MatrixOperation matrixOperation;

    /**
     * Exports the trainable parameters.
     *
     * @return a model holding whatever {@code Matrix.getMatrix()} returns for
     *         {@code bTa} and {@code power}
     *         (NOTE(review): may be the live backing arrays rather than copies — confirm)
     */
    public LayNormModel getModel() {
        LayNormModel layNormModel = new LayNormModel();
        layNormModel.setbTa(this.bTa.getMatrix());
        layNormModel.setPower(this.power.getMatrix());
        return layNormModel;
    }

    /**
     * Loads trainable parameters from a previously exported model.
     *
     * <p>Both arrays are validated against the layer's dimensions before
     * anything is written, so a malformed model can no longer leave the layer
     * with partially overwritten weights. Arrays larger than required are still
     * accepted; the surplus entries are ignored.
     *
     * @param layNormModel model to read {@code power} and {@code bTa} from
     * @throws IllegalArgumentException if the model is {@code null} or one of
     *         its arrays is missing or smaller than the corresponding matrix
     * @throws Exception if a matrix accessor fails
     */
    public void insertModel(LayNormModel layNormModel) throws Exception {
        if (layNormModel == null) {
            throw new IllegalArgumentException("layNormModel must not be null");
        }
        double[][] modelPower = layNormModel.getPower();
        double[][] modelBTa = layNormModel.getbTa();
        // Validate everything first: the copy below mutates the layer in place.
        LayNorm.checkShape("power", modelPower, this.power);
        LayNorm.checkShape("bTa", modelBTa, this.bTa);
        this.insertPower(modelPower, this.power);
        this.insertPower(modelBTa, this.bTa);
    }

    /**
     * Verifies that {@code values} provides at least as many rows and columns
     * as {@code target} has.
     *
     * @param name   parameter name used in the error message
     * @param values array read from the model, may be {@code null}
     * @param target matrix that is about to be filled
     * @throws IllegalArgumentException if {@code values} is too small
     * @throws Exception if a matrix accessor fails
     */
    private static void checkShape(String name, double[][] values, Matrix target) throws Exception {
        int rows = target.getX();
        int cols = target.getY();
        int availableRows = values == null ? 0 : values.length;
        if (availableRows < rows) {
            throw new IllegalArgumentException("model " + name + " has " + availableRows
                    + " rows, expected at least " + rows);
        }
        for (int i = 0; i < rows; ++i) {
            int availableCols = values[i] == null ? 0 : values[i].length;
            if (availableCols < cols) {
                throw new IllegalArgumentException("model " + name + " row " + i + " has "
                        + availableCols + " columns, expected at least " + cols);
            }
        }
    }

    /**
     * Copies {@code modelPower} element by element into {@code power}.
     * The extent of the copy is taken from the target matrix, not the array.
     */
    private void insertPower(double[][] modelPower, Matrix power) throws Exception {
        for (int i = 0; i < power.getX(); ++i) {
            for (int j = 0; j < power.getY(); ++j) {
                power.setNub(i, j, modelPower[i][j]);
            }
        }
    }

    /**
     * Creates the layer and randomly initializes its parameters.
     *
     * <p>Every entry of {@code bTa} and {@code power} is drawn uniformly from
     * {@code [0, 1 / sqrt(featureDimension))}.
     *
     * @param type              routing selector (see {@link #addNorm} and {@link #backErrorFromLine})
     * @param featureDimension  width of a feature row
     * @param myEncoderBlock    block that receives the output, may be {@code null} when {@code type == 1}
     * @param firstDecoderBlock alternative receiver used when {@code type == 1} and
     *                          {@code myEncoderBlock} is {@code null}
     * @param study             learning rate
     * @param coreNumber        passed through to {@code MatrixOperation}
     * @throws Exception if matrix construction fails
     */
    public LayNorm(int type, int featureDimension, CodecBlock myEncoderBlock, FirstDecoderBlock firstDecoderBlock, double study, int coreNumber) throws Exception {
        this.study = study;
        this.myEncoderBlock = myEncoderBlock;
        this.type = type;
        this.featureDimension = featureDimension;
        this.firstDecoderBlock = firstDecoderBlock;
        this.matrixOperation = new MatrixOperation(coreNumber);
        this.bTa = new Matrix(1, featureDimension);
        this.power = new Matrix(featureDimension, featureDimension);
        Random random = new Random();
        double sh = Math.sqrt(featureDimension);
        for (int i = 0; i < featureDimension; ++i) {
            double value = random.nextDouble() / sh;
            this.bTa.setNub(0, i, value);
        }
        for (int i = 0; i < featureDimension; ++i) {
            for (int j = 0; j < featureDimension; ++j) {
                double value = random.nextDouble() / sh;
                this.power.setNub(i, j, value);
            }
        }
    }

    /**
     * Back-propagates one error row: updates {@code power} and returns the
     * error with respect to the un-normalized input row.
     *
     * @param errorMatrix error row for the layer output
     * @param myData      normalized input row recorded during the forward pass
     * @return a 1 x N row, N being the width of the propagated error
     */
    private Matrix back(Matrix errorMatrix, Matrix myData) throws Exception {
        // NOTE(review): matrixMulPd is opaque here; presumably the boolean selects between the
        // delta for the weights (false) and the error for the input (true) — confirm.
        Matrix subPower = this.matrixOperation.matrixMulPd(errorMatrix, myData, this.power, false);
        Matrix sub = this.matrixOperation.matrixMulPd(errorMatrix, myData, this.power, true);
        this.power = this.matrixOperation.add(subPower, this.power);
        double n = Math.sqrt(sub.getY());
        // Coefficient applied to every *other* position; only used when sub has more than one column.
        double nt = -n / (n - 1.0);
        Matrix subMatrix = new Matrix(1, sub.getY());
        for (int i = 0; i < sub.getY(); ++i) {
            double subValue = sub.getNumber(0, i);
            // Contribution of element i to its own position ...
            double value = subValue * n * this.study + subMatrix.getNumber(0, i);
            subMatrix.setNub(0, i, value);
            // ... and to every other position of the row.
            for (int j = 0; j < sub.getY(); ++j) {
                if (i == j) continue;
                double otherValue = subValue * nt * this.study + subMatrix.getNumber(0, j);
                subMatrix.setNub(0, j, otherValue);
            }
        }
        return subMatrix;
    }

    /**
     * Accumulates one error matrix; on the {@code featureDimension}-th call the
     * accumulated error is trimmed, combined with {@code allError} and
     * back-propagated through this layer.
     *
     * @param errorMatrix error contribution of one sender
     * @param eventID     identifier of the sample being trained
     * @param allError    error added to the trimmed sum before back-propagation
     */
    public void backErrorFromFNN(Matrix errorMatrix, long eventID, Matrix allError) throws Exception {
        ++this.number;
        this.myFinalError = this.myFinalError == null ? errorMatrix : this.matrixOperation.add(this.myFinalError, errorMatrix);
        if (this.number == this.featureDimension) {
            this.number = 0;
            // NOTE(review): presumably drops the last column (a bias term?) of the summed error — confirm
            // against Matrix.getSonOfMatrix.
            Matrix error = this.myFinalError.getSonOfMatrix(0, 0, this.myFinalError.getX(), this.myFinalError.getY() - 1);
            this.myFinalError = null;
            Matrix myError = this.matrixOperation.add(error, allError);
            this.backErrorFromLine(myError, eventID);
        }
    }

    /**
     * Adds {@code errorMatrix} to the pending error without triggering
     * back-propagation; {@link #encoderBackStart} consumes the sum later.
     */
    public void backLastError(Matrix errorMatrix) throws Exception {
        this.myFinalError = this.myFinalError == null ? errorMatrix : this.matrixOperation.add(this.myFinalError, errorMatrix);
    }

    /**
     * Back-propagates the error accumulated by {@link #backLastError} and
     * clears it. Requires that at least one error has been accumulated.
     *
     * @param eventID identifier of the sample being trained
     */
    public void encoderBackStart(long eventID) throws Exception {
        Matrix error = this.myFinalError.copy();
        this.myFinalError = null;
        this.backErrorFromLine(error, eventID);
    }

    /**
     * Runs the backward pass of this layer for a full error matrix.
     *
     * <p>For every row the bias is updated, {@link #back} updates the weights,
     * and the resulting input-error rows are stacked and handed to the hidden
     * nerves ({@code type == 2}) or to the self-attention layer (otherwise).
     *
     * @param errorMatrix error for the layer output, one row per position
     * @param eventID     identifier of the sample being trained
     */
    public void backErrorFromLine(Matrix errorMatrix, long eventID) throws Exception {
        // Return value is ignored, so this presumably scales errorMatrix in place by the learning rate.
        this.matrixOperation.mathMul(errorMatrix, this.study);
        int x = errorMatrix.getX();
        Matrix myError = null;
        for (int i = 0; i < x; ++i) {
            Matrix error = errorMatrix.getRow(i);
            Matrix myData = this.myNormData.getRow(i);
            this.bTa = this.matrixOperation.add(error, this.bTa);
            Matrix myRowError = this.back(error, myData);
            // pushVector(..., true) presumably appends the vector as a new row.
            myError = i == 0 ? myRowError : this.matrixOperation.pushVector(myError, myRowError, true);
        }
        if (this.type == 2) {
            int size = this.hiddenNerves.size();
            for (int i = 0; i < size; ++i) {
                // Each hidden nerve gets its own column plus the complete error matrix.
                this.hiddenNerves.get(i).receiveErrorMatrix(myError.getColumn(i), eventID, myError);
            }
        } else {
            this.multiSelfAttention.backError(myError, eventID);
        }
    }

    /**
     * Forward pass: normalizes {@code feature + outMatrix} and forwards the
     * result.
     *
     * <p>Routing: with {@code type == 1} the output goes to the hidden nerves
     * when an encoder block is set, otherwise to the first decoder block when
     * one is set; with any other {@code type} it goes to the encoder block.
     *
     * @param feature        first summand
     * @param outMatrix      second summand
     * @param eventID        identifier of the sample
     * @param isStudy        {@code true} while training; normalized rows are then kept for the backward pass
     * @param outBack        passed through to the receiver
     * @param E              passed through to the receiver
     * @param encoderFeature passed through to the receiver
     * @param outAllPro      passed through to the receiver
     */
    public void addNorm(Matrix feature, Matrix outMatrix, long eventID, boolean isStudy, OutBack outBack, List<Integer> E, Matrix encoderFeature, boolean outAllPro) throws Exception {
        Matrix myMatrix = this.matrixOperation.add(feature, outMatrix);
        Matrix out = this.layNorm(myMatrix, isStudy);
        if (this.type == 1) {
            if (this.myEncoderBlock != null) {
                this.sendHiddenParameter(out, eventID, isStudy, outBack, E, encoderFeature, outAllPro);
            } else if (this.firstDecoderBlock != null) {
                this.firstDecoderBlock.sendOutputMatrix(eventID, out, isStudy, outBack, E, outAllPro);
            }
        } else {
            this.myEncoderBlock.sendOutputMatrix(eventID, out, isStudy, outBack, E, encoderFeature, outAllPro);
        }
    }

    /**
     * Collects the partial outputs of individual nerves for one event and runs
     * {@link #addNorm} once the collected matrix is {@code featureDimension}
     * columns wide.
     *
     * @param eventID    identifier of the sample; key of the collection buffer
     * @param parameter  partial output to append
     *                   (pushVector(..., false) presumably appends it as a column)
     * @param allFeature matrix the completed feature is added to
     */
    public void addNormFromNerve(long eventID, boolean isStudy, Matrix parameter, Matrix allFeature, OutBack outBack, List<Integer> E, Matrix encoderFeature, boolean outAllPro) throws Exception {
        Matrix matrixFeature;
        if (this.reMatrixMap.containsKey(eventID)) {
            Matrix myFeature = this.reMatrixMap.get(eventID);
            matrixFeature = this.matrixOperation.pushVector(myFeature, parameter, false);
        } else {
            matrixFeature = parameter;
        }
        this.reMatrixMap.put(eventID, matrixFeature);
        if (matrixFeature.getY() == this.featureDimension) {
            // Complete: release the buffer before forwarding.
            this.reMatrixMap.remove(eventID);
            this.addNorm(matrixFeature, allFeature, eventID, isStudy, outBack, E, encoderFeature, outAllPro);
        }
    }

    /** Hands the same feature matrix to every hidden nerve. */
    private void sendHiddenParameter(Matrix feature, long eventId, boolean isStudy, OutBack outBack, List<Integer> E, Matrix encoderFeature, boolean outAllPro) throws Exception {
        for (HiddenNerve hiddenNerve : this.hiddenNerves) {
            hiddenNerve.receive(feature, eventId, isStudy, outBack, E, encoderFeature, outAllPro);
        }
    }

    /**
     * Normalizes a single row: {@code (value - mean) / sd}.
     *
     * @param row 1 x N row
     * @return a new 1 x N row
     */
    private Matrix norm(Matrix row) throws Exception {
        Matrix result = new Matrix(1, row.getY());
        double avg = row.getAVG();
        // NOTE(review): 1.0E-5 is presumably a stabilizing epsilon inside getSdByMatrix — confirm.
        double sd = this.matrixOperation.getSdByMatrix(row, avg, 1.0E-5);
        for (int i = 0; i < row.getY(); ++i) {
            double value = (row.getNumber(0, i) - avg) / sd;
            result.setNub(0, i, value);
        }
        return result;
    }

    /**
     * Normalizes every row of {@code feature} and applies
     * {@code normalized * power + bTa}.
     *
     * @param feature input matrix
     * @param isStudy when {@code true}, the normalized rows replace {@code myNormData}
     * @return the transformed matrix, or {@code null} if {@code feature} has no rows
     */
    private Matrix layNorm(Matrix feature, boolean isStudy) throws Exception {
        int x = feature.getX();
        Matrix out = null;
        if (isStudy) {
            this.myNormData = null;
        }
        for (int i = 0; i < x; ++i) {
            Matrix normData = this.norm(feature.getRow(i));
            if (isStudy) {
                this.myNormData = i == 0 ? normData : this.matrixOperation.pushVector(this.myNormData, normData, true);
            }
            Matrix want = this.matrixOperation.add(this.matrixOperation.mulMatrix(normData, this.power), this.bTa);
            out = i == 0 ? want : this.matrixOperation.pushVector(out, want, true);
        }
        return out;
    }

    /** Sets the hidden nerves used for forward ({@code type == 1}) and backward ({@code type == 2}) routing. */
    public void setHiddenNerves(List<HiddenNerve> hiddenNerves) {
        this.hiddenNerves = hiddenNerves;
    }

    /** Sets the self-attention layer that receives the back-propagated error when {@code type != 2}. */
    public void setMultiSelfAttention(MultiSelfAttention multiSelfAttention) {
        this.multiSelfAttention = multiSelfAttention;
    }
}

