/*
 * Decompiled with CFR 0.152.
 */
package org.dromara.nerveEntity;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.dromara.i.OutBack;
import org.dromara.nerveEntity.Nerve;
import org.dromara.nerveEntity.OutNerve;

/**
 * Softmax output layer.
 *
 * <p>For each event it collects the inputs delivered via {@link #input}. Once
 * {@code insertParameter} reports that all of them have arrived, it turns them into a
 * probability distribution. It then either:
 * <ul>
 *   <li>reports the most likely class id to an {@link OutBack} (inference), or</li>
 *   <li>sends a per-class error term back to the {@link OutNerve}s (training).</li>
 * </ul>
 *
 * <p>Class ids are 1-based: probability index {@code i} corresponds to class id {@code i + 1}.
 */
public class SoftMax
extends Nerve {
    /** Neurons that receive the training error; element {@code i} gets the error for class id {@code i + 1}. */
    private final List<OutNerve> outNerves;
    /** When {@code true}, prints the target / winning probability / predicted id for each training sample. */
    private final boolean isShowLog;

    /**
     * Creates the softmax layer.
     *
     * @param upNub      forwarded to {@link Nerve} as its upstream-count argument
     * @param isDynamic  forwarded to {@link Nerve}
     * @param outNerves  neurons that receive the training error, ordered by class id
     * @param isShowLog  whether to print a log line per training sample
     * @param coreNumber forwarded to {@link Nerve}
     * @throws Exception if the {@link Nerve} constructor throws
     */
    public SoftMax(int upNub, boolean isDynamic, List<OutNerve> outNerves, boolean isShowLog, int coreNumber) throws Exception {
        super(0, upNub, "softMax", 0, 0.0, false, null, isDynamic, 0, 0.0, 0, 0, 0, 0, 0, coreNumber);
        this.outNerves = outNerves;
        this.isShowLog = isShowLog;
    }

    /**
     * Receives one input value for {@code eventId}. Once all inputs for the event are present,
     * computes the softmax.
     *
     * <ul>
     *   <li>When training, the error is pushed to every {@link OutNerve}.</li>
     *   <li>Otherwise, the prediction is reported through {@code outBack}.</li>
     * </ul>
     *
     * @param eventId   identifier grouping the inputs of one sample
     * @param parameter one input value for this event
     * @param isStudy   {@code true} for training, {@code false} for inference
     * @param E         expected outputs keyed by class id (read only when {@code isStudy})
     * @param outBack   result callback (required when {@code isStudy} is {@code false})
     * @throws Exception if inference finishes and {@code outBack} is {@code null}, or if a
     *                   downstream call throws
     */
    @Override
    protected void input(long eventId, double parameter, boolean isStudy, Map<Integer, Double> E, OutBack outBack) throws Exception {
        boolean allReady = this.insertParameter(eventId, parameter);
        if (!allReady) {
            return;
        }
        Mes mes = this.softMax(eventId, isStudy);
        if (isStudy) {
            int key = this.findTargetKey(E);
            if (this.isShowLog) {
                System.out.println("softMax==" + key + ",out==" + mes.poi + ",nerveId==" + mes.typeID);
            }
            List<Double> errors = this.error(mes, key);
            // NOTE(review): the inference branch calls destoryParameter(eventId) instead;
            // confirm the two clean-up paths are equivalent.
            this.features.remove(eventId);
            int size = this.outNerves.size();
            for (int i = 0; i < size; ++i) {
                this.outNerves.get(i).getGBySoftMax(errors.get(i), eventId);
            }
        } else {
            // Release the event's state before reporting, so it is freed even if we throw below.
            this.destoryParameter(eventId);
            if (outBack != null) {
                outBack.getBack(mes.poi, mes.typeID, eventId);
                outBack.getSoftMaxBack(eventId, mes.softMax);
            } else {
                throw new Exception("not find outBack");
            }
        }
    }

    /**
     * Finds the target class id among the expected outputs.
     *
     * <p>Entries are scanned in {@code E}'s iteration order. The first entry whose value is
     * greater than 0.9 wins.
     *
     * @param E expected outputs keyed by class id
     * @return the target class id, or {@code 0} when no value exceeds 0.9. With 0, every
     *         softmax error is simply {@code -p}. NOTE(review): confirm this is intended
     *         rather than a missing label.
     */
    private int findTargetKey(Map<Integer, Double> E) {
        for (Map.Entry<Integer, Double> entry : E.entrySet()) {
            if (entry.getValue() > 0.9) {
                return entry.getKey();
            }
        }
        return 0;
    }

    /**
     * Computes the per-class error {@code onehot(key) - p}.
     *
     * <p>This is the negative gradient of the cross-entropy loss with respect to the softmax
     * inputs.
     *
     * @param mes softmax result for the sample
     * @param key 1-based target class id
     * @return one error value per class, indexed like {@code mes.softMax}
     */
    private List<Double> error(Mes mes, int key) {
        int t = key - 1;
        List<Double> softMax = mes.softMax;
        ArrayList<Double> error = new ArrayList<Double>(softMax.size());
        for (int i = 0; i < softMax.size(); ++i) {
            double self = softMax.get(i);
            double myError = i != t ? -self : 1.0 - self;
            error.add(myError);
        }
        return error;
    }

    /**
     * Applies softmax to the inputs collected for {@code eventId}.
     *
     * @param eventId event whose collected inputs are normalized
     * @param isStudy unused; kept for call-site compatibility
     * @return the probabilities and the winner. {@code typeID} is the 1-based index of the
     *         first maximum and {@code poi} is its probability. With no inputs, the result is
     *         {@code 0} and {@code 0.0}.
     */
    private Mes softMax(long eventId, boolean isStudy) {
        @SuppressWarnings("unchecked")
        List<Double> featuresList = (List<Double>) this.features.get(eventId);
        int size = featuresList.size();

        // Subtract the largest input before exponentiating. Softmax is unchanged by a constant
        // shift, and without the shift Math.exp overflows to Infinity for inputs above ~709.
        // That would turn every probability into NaN (Infinity / Infinity).
        double max = Double.NEGATIVE_INFINITY;
        for (double value : featuresList) {
            if (value > max) {
                max = value;
            }
        }

        // Compute each exponential once and reuse it for both the sum and the normalization.
        double[] exps = new double[size];
        double sigma = 0.0;
        int index = 0;
        for (double value : featuresList) {
            double e = Math.exp(value - max);
            exps[index++] = e;
            sigma += e;
        }

        ArrayList<Double> softMax = new ArrayList<Double>(size);
        int id = 0;
        double poi = 0.0;
        for (int i = 0; i < size; ++i) {
            double value = exps[i] / sigma;
            softMax.add(value);
            // Strict '>' keeps the first index on ties.
            if (value > poi) {
                poi = value;
                id = i + 1;
            }
        }

        Mes mes = new Mes();
        mes.softMax = softMax;
        mes.typeID = id;
        mes.poi = poi;
        return mes;
    }

    /** Result of one softmax evaluation. */
    static class Mes {
        /** 1-based id of the most probable class (0 if there were no inputs). */
        int typeID;
        /** Probability of {@link #typeID}. */
        double poi;
        /** Probability per class; index {@code i} is class id {@code i + 1}. */
        List<Double> softMax;

        Mes() {
        }
    }
}

