/*
 * Decompiled with CFR 0.152.
 */
package org.wlld.transFormer.nerve;

import java.util.ArrayList;
import java.util.List;
import org.wlld.i.OutBack;
import org.wlld.matrixTools.Matrix;
import org.wlld.matrixTools.MatrixOperation;
import org.wlld.transFormer.nerve.Nerve;
import org.wlld.transFormer.nerve.OutNerve;

/**
 * Softmax output stage.
 *
 * <p>Once {@link #insertMatrixParameter} reports that the feature matrix for an event is complete,
 * each row of that matrix is turned into a probability distribution with softmax.
 *
 * <ul>
 *   <li>Training ({@code isStudy == true}): row {@code i} is compared against the 1-based expected
 *       class {@code E.get(i)}. The per-row error vectors {@code onehot(E.get(i)) - softmax(row i)}
 *       are combined with {@link MatrixOperation#pushVector}. Column {@code j} of the combined
 *       matrix is handed to {@code outNerves.get(j).getGBySoftMax}.</li>
 *   <li>Inference: only the last row is classified. The winning 1-based class id and its
 *       probability are reported through {@link OutBack#getBack}. If {@code outAllPro} is set, the
 *       full distribution is also reported through {@link OutBack#getSoftMaxBack}.</li>
 * </ul>
 */
public class SoftMax extends Nerve {
    /** Nerves receiving the back-propagated error; column {@code j} goes to element {@code j}. */
    private final List<OutNerve> outNerves;
    /** When true, each training row's expected class and prediction is printed to stdout. */
    private final boolean isShowLog;
    private final MatrixOperation matrixOperation = new MatrixOperation();

    /**
     * Creates the softmax stage.
     *
     * @param outNerves       nerves that receive the training error, one per error column
     * @param isShowLog       whether to print per-row training diagnostics
     * @param sensoryNerveNub forwarded unchanged to the {@link Nerve} constructor
     * @param hiddenNerveNub  forwarded unchanged to the {@link Nerve} constructor
     * @param outNerveNub     forwarded unchanged to the {@link Nerve} constructor
     * @throws Exception if the {@link Nerve} constructor throws
     */
    public SoftMax(List<OutNerve> outNerves, boolean isShowLog, int sensoryNerveNub, int hiddenNerveNub, int outNerveNub) throws Exception {
        // NOTE(review): the literal arguments (0, "softMax", 0.0, null, ..., null, 0, 0.0, 1) are
        // interpreted by Nerve's constructor, which is not visible here — confirm before changing.
        super(0, "softMax", 0.0, null, sensoryNerveNub, hiddenNerveNub, outNerveNub, null, 0, 0.0, 1);
        this.outNerves = outNerves;
        this.isShowLog = isShowLog;
    }

    /**
     * Accepts one input for {@code eventId}. Once the event's feature matrix is complete, this
     * either back-propagates the training error or reports an inference result.
     *
     * @param eventId   identifier of the event being processed
     * @param parameter input contributed by an upstream nerve
     * @param isStudy   true for training, false for inference
     * @param outBack   result sink for inference; required when {@code isStudy} is false
     * @param E         1-based expected class per row of the feature matrix (training only)
     * @param outAllPro when true (inference only), also report the full probability distribution
     * @throws Exception if {@code E} is null or its size differs from the number of rows; if the
     *                   feature matrix has no rows; if inference is requested without an
     *                   {@code outBack}; or if a downstream call throws
     */
    @Override
    protected void toOut(long eventId, Matrix parameter, boolean isStudy, OutBack outBack, List<Integer> E, boolean outAllPro) throws Exception {
        // Presumably true once all inputs for eventId have been collected — semantics defined in Nerve.
        boolean allReady = this.insertMatrixParameter(eventId, parameter);
        if (!allReady) {
            return;
        }
        Matrix feature = (Matrix) this.reMatrixFeatures.get(eventId);
        this.reMatrixFeatures.remove(eventId);
        int x = feature.getX();
        if (x == 0) {
            // Without this guard, training hit an NPE on the never-assigned error matrix and
            // inference asked for row -1.
            throw new Exception("softMax: feature matrix for event " + eventId + " has no rows");
        }
        if (isStudy) {
            this.backError(eventId, feature, x, E);
        } else if (outBack != null) {
            // Inference classifies only the final row (e.g. the last position of a sequence).
            Mes mes = this.softMax(false, feature.getRow(x - 1), outAllPro);
            outBack.getBack(mes.poi, mes.typeID, eventId);
            if (outAllPro) {
                outBack.getSoftMaxBack(eventId, mes.softMax);
            }
        } else {
            throw new Exception("not find outBack");
        }
    }

    /**
     * Training path: computes the softmax error of every row and sends column {@code j} of the
     * combined error matrix to {@code outNerves.get(j)}.
     */
    private void backError(long eventId, Matrix feature, int x, List<Integer> E) throws Exception {
        if (E == null || E.size() != x) {
            // Message: "The expected sequence length does not match the actual sequence! Please
            // check the expectation E and add the missing entries."
            throw new Exception("\u671f\u671b\u7684\u5e8f\u5217\u957f\u5ea6\u4e0e\u5b9e\u9645\u5e8f\u5217\u4e0d\u76f8\u7b49\uff01\u8bf7\u68c0\u67e5\u671f\u671bE\uff0c\u8865\u5145\u6f0f\u6389\u7684\u5e8f\u5217");
        }
        Matrix allError = null;
        for (int i = 0; i < x; ++i) {
            Mes mes = this.softMax(true, feature.getRow(i), false);
            int key = E.get(i);
            if (this.isShowLog) {
                System.out.println("softMax==" + key + ",out==" + mes.poi + ",nerveId==" + mes.typeID);
            }
            Matrix errors = this.error(mes, key);
            // NOTE(review): pushVector(..., true) is assumed to append `errors` as a new row.
            allError = allError == null ? errors : this.matrixOperation.pushVector(allError, errors, true);
        }
        int size = this.outNerves.size();
        for (int i = 0; i < size; ++i) {
            this.outNerves.get(i).getGBySoftMax(allError.getColumn(i), eventId);
        }
    }

    /**
     * Builds the 1 x n error row {@code y - p}. Here {@code y} is the one-hot vector for the
     * 1-based class {@code key} and {@code p} is the softmax output. For softmax followed by
     * cross-entropy, this is the negated gradient of the loss with respect to the softmax inputs.
     *
     * <p>NOTE(review): if {@code key} lies outside {@code [1, n]}, no entry receives the
     * {@code +1} term and the row is simply {@code -p}.
     */
    private Matrix error(Mes mes, int key) throws Exception {
        int target = key - 1; // key is 1-based, matching Mes.typeID
        List<Double> probabilities = mes.softMax;
        int size = probabilities.size();
        Matrix matrix = new Matrix(1, size);
        for (int i = 0; i < size; ++i) {
            double p = probabilities.get(i);
            matrix.setNub(0, i, i == target ? 1.0 - p : -p);
        }
        return matrix;
    }

    /**
     * Applies softmax to row 0 of {@code matrix} and finds the most probable class.
     *
     * @param isStudy   when true, the full distribution is kept in {@link Mes#softMax}
     * @param matrix    1 x n row of scores
     * @param outAllPro when true, the full distribution is kept even if not training
     * @return the 1-based argmax ({@code 0} if no probability exceeds 0, e.g. all NaN), its
     *         probability, and the distribution (empty unless {@code isStudy || outAllPro})
     */
    private Mes softMax(boolean isStudy, Matrix matrix, boolean outAllPro) throws Exception {
        int size = matrix.getY();
        // Subtract the row maximum before exponentiating. Softmax is invariant to this shift, and
        // it stops Math.exp from overflowing to Infinity, which made every probability NaN.
        double max = Double.NEGATIVE_INFINITY;
        for (int j = 0; j < size; ++j) {
            max = Math.max(max, matrix.getNumber(0, j));
        }
        double[] exps = new double[size];
        double sigma = 0.0;
        for (int j = 0; j < size; ++j) {
            exps[j] = Math.exp(matrix.getNumber(0, j) - max);
            sigma += exps[j];
        }
        boolean keepAll = isStudy || outAllPro;
        List<Double> softMax = new ArrayList<>(keepAll ? size : 0);
        int id = 0;
        double poi = 0.0;
        for (int i = 0; i < size; ++i) {
            double value = exps[i] / sigma;
            if (keepAll) {
                softMax.add(value);
            }
            // Strict '>' keeps the first index on ties; NaN never wins.
            if (value > poi) {
                poi = value;
                id = i + 1;
            }
        }
        Mes mes = new Mes();
        mes.softMax = softMax;
        mes.typeID = id;
        mes.poi = poi;
        return mes;
    }

    /** Result of {@link #softMax}: winning class, its probability, and optionally all probabilities. */
    static class Mes {
        /** 1-based index of the most probable class, or 0 if none exceeded probability 0. */
        int typeID;
        /** Probability of {@link #typeID}. */
        double poi;
        /** Full probability distribution; empty when it was not requested. */
        List<Double> softMax;

        Mes() {
        }
    }
}

