/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.fst;

import cc.mallet.fst.SumLattice;
import cc.mallet.fst.SumLatticeFactory;
import cc.mallet.fst.Transducer;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.LabelVector;
import cc.mallet.types.Sequence;
import cc.mallet.util.MalletLogger;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.logging.Logger;

public class SumLatticeScaling
implements SumLattice {
    /** Class-wide logger; it is never reassigned, so it is declared final. */
    private static final Logger logger = MalletLogger.getLogger(SumLatticeScaling.class.getName());
    /** Default used for the {@code saveXis} argument by the constructor overloads that do not take one. */
    protected static boolean saveXis = false;
    /** Input sequence the lattice was built over. */
    Sequence input;
    /** Output sequence passed to the transition iterators; may be {@code null}. */
    Sequence output;
    /** Transducer supplying the states and transitions. */
    Transducer t;
    /** Log of the lattice normalizer, or {@code Double.NEGATIVE_INFINITY} when the constructor returned early. */
    double totalWeight;
    /** Lattice cells indexed {@code [inputPosition][stateIndex]}; cells are created lazily and may be {@code null}. */
    LatticeNode[][] nodes;
    /** {@code alphaLogScaling[ip]}: running sum (positions 0..ip) of the logs of the alpha normalizers. */
    double[] alphaLogScaling;
    /** {@code betaLogScaling[ip]}: running sum (positions ip..end) of the logs of the beta normalizers. */
    double[] betaLogScaling;
    /** Copy of {@code alphaLogScaling} at the last lattice position; the log scale attached to Z. */
    double zLogScaling;
    /** Number of lattice positions: {@code input.size() + 1}. */
    int latticeLength;
    /** {@code gammas[ip][i]}: log of the state posterior; initialised to {@code Double.NEGATIVE_INFINITY}. */
    double[][] gammas;
    /** {@code xis[ip][i][j]}: log of the transition posterior; allocated only when xis are saved, otherwise {@code null}. */
    double[][][] xis;

    /**
     * Creates an instance without assigning any field.
     * NOTE(review): presumably meant for subclasses that fill in the lattice themselves — confirm.
     */
    protected SumLatticeScaling() {
    }

    /**
     * Returns the lattice cell for the given position and state, creating and caching it on first access.
     *
     * @param ip input position (row of the lattice)
     * @param stateIndex index of the transducer state (column of the lattice)
     * @return the existing or newly created cell; a new cell starts with NaN alpha and beta
     */
    protected LatticeNode getLatticeNode(int ip, int stateIndex) {
        LatticeNode[] row = this.nodes[ip];
        LatticeNode cell = row[stateIndex];
        if (cell == null) {
            cell = new LatticeNode(ip, this.t.getState(stateIndex));
            row[stateIndex] = cell;
        }
        return cell;
    }

    /**
     * Builds a lattice with no output sequence, incrementor or output alphabet (all passed as
     * {@code null}); whether xis are stored is taken from the static {@code saveXis} field.
     *
     * @param trans transducer supplying states and transitions
     * @param input input sequence
     */
    public SumLatticeScaling(Transducer trans, Sequence input) {
        this(trans, input, null, null, saveXis, null);
    }

    /**
     * Builds a lattice with no output sequence, incrementor or output alphabet (all passed as {@code null}).
     *
     * @param trans transducer supplying states and transitions
     * @param input input sequence
     * @param saveXis whether to allocate and fill the xi table (this parameter shadows the static field)
     */
    public SumLatticeScaling(Transducer trans, Sequence input, boolean saveXis) {
        this(trans, input, null, null, saveXis, null);
    }

    /**
     * Builds a lattice with no output sequence or output alphabet, reporting posteriors to
     * {@code incrementor}; whether xis are stored is taken from the static {@code saveXis} field.
     *
     * @param trans transducer supplying states and transitions
     * @param input input sequence
     * @param incrementor receiver of initial-state, transition and final-state posteriors
     */
    public SumLatticeScaling(Transducer trans, Sequence input, Transducer.Incrementor incrementor) {
        this(trans, input, null, incrementor, saveXis, null);
    }

    /**
     * Builds a lattice for an input/output pair with no incrementor or output alphabet;
     * whether xis are stored is taken from the static {@code saveXis} field.
     *
     * @param trans transducer supplying states and transitions
     * @param input input sequence
     * @param output output sequence handed to the transition iterators
     */
    public SumLatticeScaling(Transducer trans, Sequence input, Sequence output) {
        this(trans, input, output, null, saveXis, null);
    }

    /**
     * Builds a lattice for an input/output pair, reporting posteriors to {@code incrementor},
     * with no output alphabet; whether xis are stored is taken from the static {@code saveXis} field.
     *
     * @param trans transducer supplying states and transitions
     * @param input input sequence
     * @param output output sequence handed to the transition iterators
     * @param incrementor receiver of initial-state, transition and final-state posteriors
     */
    public SumLatticeScaling(Transducer trans, Sequence input, Sequence output, Transducer.Incrementor incrementor) {
        this(trans, input, output, incrementor, saveXis, null);
    }

    /**
     * Builds a lattice for an input/output pair with an incrementor and an output alphabet;
     * whether xis are stored is taken from the static {@code saveXis} field.
     *
     * @param trans transducer supplying states and transitions
     * @param input input sequence
     * @param output output sequence handed to the transition iterators
     * @param incrementor receiver of initial-state, transition and final-state posteriors
     * @param outputAlphabet alphabet used to look up transition outputs
     */
    public SumLatticeScaling(Transducer trans, Sequence input, Sequence output, Transducer.Incrementor incrementor, LabelAlphabet outputAlphabet) {
        this(trans, input, output, incrementor, saveXis, outputAlphabet);
    }

    /**
     * Builds a lattice for an input/output pair with an incrementor and an explicit
     * {@code saveXis} choice, with no output alphabet.
     *
     * @param trans transducer supplying states and transitions
     * @param input input sequence
     * @param output output sequence handed to the transition iterators
     * @param incrementor receiver of initial-state, transition and final-state posteriors
     * @param saveXis whether to allocate and fill the xi table (this parameter shadows the static field)
     */
    public SumLatticeScaling(Transducer trans, Sequence input, Sequence output, Transducer.Incrementor incrementor, boolean saveXis) {
        this(trans, input, output, incrementor, saveXis, null);
    }

    /**
     * Builds the lattice by running a forward pass and, when the sequence has positive total
     * probability, a backward pass. Alphas and betas are kept in the probability domain and
     * renormalised at every position; the logs of the normalizers are accumulated in
     * {@code alphaLogScaling} / {@code betaLogScaling} so that log-domain gammas, xis and the
     * total weight can be recovered.
     *
     * <p>If no node is reached at the final position, or the normalizer Z is zero, the total
     * weight is set to {@code Double.NEGATIVE_INFINITY} and the constructor returns before the
     * backward pass: gammas and xis keep their {@code NEGATIVE_INFINITY} initial values and no
     * incrementor method is called.
     *
     * @param trans transducer supplying states and transitions
     * @param input input sequence; the lattice has {@code input.size() + 1} positions
     * @param output output sequence handed to the transition iterators, or {@code null};
     *        when non-null it must have the same size as {@code input} (checked by assertion)
     * @param incrementor receiver of initial-state, transition and final-state posteriors, or {@code null}
     * @param saveXis whether to allocate and fill the xi table (this parameter shadows the static field)
     * @param outputAlphabet alphabet used to look up transition outputs, or {@code null}
     */
    public SumLatticeScaling(Transducer trans, Sequence input, Sequence output, Transducer.Incrementor incrementor, boolean saveXis, LabelAlphabet outputAlphabet) {
        assert (output == null || input.size() == output.size());
        this.t = trans;
        this.input = input;
        this.output = output;
        this.latticeLength = input.size() + 1;
        int numStates = this.t.numStates();
        this.nodes = new LatticeNode[this.latticeLength][numStates];
        this.alphaLogScaling = new double[this.latticeLength];
        this.betaLogScaling = new double[this.latticeLength];
        this.gammas = new double[this.latticeLength][numStates];
        if (saveXis) {
            this.xis = new double[this.latticeLength][numStates][numStates];
        }
        // NOTE(review): outputCounts is accumulated below but never read or exposed by this
        // constructor — confirm whether it is still needed.
        double[][] outputCounts = null;
        if (outputAlphabet != null) {
            outputCounts = new double[this.latticeLength][outputAlphabet.size()];
        }

        // Log-domain tables start at "impossible"; scaling logs start at zero.
        for (int ip = 0; ip < this.latticeLength; ++ip) {
            this.alphaLogScaling[ip] = 0.0;
            this.betaLogScaling[ip] = 0.0;
            for (int i = 0; i < numStates; ++i) {
                this.gammas[ip][i] = Double.NEGATIVE_INFINITY;
                if (saveXis) {
                    for (int j = 0; j < numStates; ++j) {
                        this.xis[ip][i][j] = Double.NEGATIVE_INFINITY;
                    }
                }
            }
        }

        // ---- Forward pass ----
        logger.fine("Starting Forward pass");
        boolean atLeastOneInitialState = false;
        for (int i = 0; i < numStates; ++i) {
            double initialWeight = this.t.getState(i).getInitialWeight();
            if (initialWeight > Double.NEGATIVE_INFINITY) {
                this.getLatticeNode(0, i).alpha = Math.exp(initialWeight);
                atLeastOneInitialState = true;
            }
        }
        // Warn before rescaling: with assertions enabled rescaleAlphas(0) fails on an empty
        // first column, and the warning would otherwise never be logged.
        if (!atLeastOneInitialState) {
            logger.warning("There are no starting states!");
        }
        this.rescaleAlphas(0);

        for (int ip = 0; ip < this.latticeLength - 1; ++ip) {
            for (int i = 0; i < numStates; ++i) {
                if (this.isInvalidNode(ip, i)) {
                    continue;
                }
                Transducer.State s = this.t.getState(i);
                Transducer.TransitionIterator iter = s.transitionIterator(input, ip, output, ip);
                while (iter.hasNext()) {
                    Transducer.State destination = iter.next();
                    LatticeNode destinationNode = this.getLatticeNode(ip + 1, destination.getIndex());
                    // A freshly created node has NaN alpha; start its sum at zero.
                    if (Double.isNaN(destinationNode.alpha)) {
                        destinationNode.alpha = 0.0;
                    }
                    destinationNode.output = iter.getOutput();
                    double transitionWeight = iter.getWeight();
                    destinationNode.alpha += this.nodes[ip][i].alpha * Math.exp(transitionWeight);
                }
            }
            this.rescaleAlphas(ip + 1);
        }

        // ---- Normalizer ----
        // Z is computed from the rescaled alphas of the last position; zLogScaling carries
        // the log of the scale that was divided out.
        int last = this.latticeLength - 1;
        double Z = Double.NaN;
        for (int i = 0; i < numStates; ++i) {
            if (this.nodes[last][i] != null) {
                if (Double.isNaN(Z)) {
                    Z = 0.0;
                }
                Z += this.nodes[last][i].alpha * Math.exp(this.t.getState(i).getFinalWeight());
            }
        }
        this.zLogScaling = this.alphaLogScaling[last];
        // Z == 0.0 (e.g. every reached state has an impossible final weight) must also stop
        // here: dividing by it below would turn every gamma and every increment into NaN.
        if (Double.isNaN(Z) || Z == 0.0) {
            this.totalWeight = Double.NEGATIVE_INFINITY;
            return;
        }
        this.totalWeight = Math.log(Z) + this.zLogScaling;

        // ---- Backward pass: final position ----
        for (int i = 0; i < numStates; ++i) {
            LatticeNode finalNode = this.nodes[last][i];
            if (finalNode != null) {
                Transducer.State s = this.t.getState(i);
                finalNode.beta = Math.exp(s.getFinalWeight());
                double gamma = finalNode.alpha * finalNode.beta / Z;
                this.gammas[last][i] = Math.log(gamma);
                if (incrementor != null) {
                    double p = gamma;
                    assert (p >= 0.0 && p <= 1.000001) : "p=" + p + ", gamma=" + this.gammas[last][i];
                    incrementor.incrementFinalState(s, p);
                }
            }
        }
        this.rescaleBetas(last);

        // ---- Backward pass: remaining positions ----
        for (int ip = this.latticeLength - 2; ip >= 0; --ip) {
            for (int i = 0; i < numStates; ++i) {
                if (this.isInvalidNode(ip, i)) {
                    continue;
                }
                LatticeNode sourceNode = this.nodes[ip][i];
                Transducer.State s = this.t.getState(i);
                Transducer.TransitionIterator iter = s.transitionIterator(input, ip, output, ip);
                // Log of the factor that converts products of rescaled alpha, rescaled beta
                // and rescaled Z back into a true posterior.
                double logScaling = this.alphaLogScaling[ip] + this.betaLogScaling[ip + 1] - this.zLogScaling;
                double pscaling = Math.exp(logScaling);
                // Start beta at zero for every reachable node, so a node with no outgoing
                // transition gets probability 0 instead of keeping NaN, which would poison
                // the beta normalizer for this and all earlier positions.
                sourceNode.beta = 0.0;
                while (iter.hasNext()) {
                    Transducer.State destination = iter.next();
                    int j = destination.getIndex();
                    LatticeNode destinationNode = this.nodes[ip + 1][j];
                    if (destinationNode == null) {
                        continue;
                    }
                    double transitionProb = Math.exp(iter.getWeight());
                    sourceNode.beta += destinationNode.beta * transitionProb;
                    double xi = sourceNode.alpha * transitionProb * destinationNode.beta / Z;
                    if (saveXis) {
                        this.xis[ip][i][j] = Math.log(xi) + logScaling;
                    }
                    if (incrementor == null && outputAlphabet == null) {
                        continue;
                    }
                    double p = xi * pscaling;
                    assert (p >= 0.0 && p <= 1.000001) : "p=" + p + ", xis[" + ip + "][" + i + "][" + j + "]=" + xi;
                    if (incrementor != null) {
                        incrementor.incrementTransition(iter, p);
                    }
                    if (outputAlphabet == null) {
                        continue;
                    }
                    int outputIndex = outputAlphabet.lookupIndex(iter.getOutput(), false);
                    assert (outputIndex >= 0);
                    outputCounts[ip][outputIndex] += p;
                }
                this.gammas[ip][i] = Math.log(sourceNode.alpha * sourceNode.beta / Z) + logScaling;
            }
            this.rescaleBetas(ip);
        }

        // Report the position-0 posteriors as initial-state expectations; states that were
        // never reached still hold NEGATIVE_INFINITY and therefore contribute 0.
        if (incrementor != null) {
            for (int i = 0; i < numStates; ++i) {
                double p = Math.exp(this.gammas[0][i]);
                assert (p >= 0.0 && p <= 1.000001) : "p=" + p;
                incrementor.incrementInitialState(this.t.getState(i), p);
            }
        }
    }

    /**
     * Tells whether the cell at the given position and state must be skipped.
     *
     * @param ip input position
     * @param i state index
     * @return {@code true} when the cell was never created or its alpha is still NaN
     */
    private boolean isInvalidNode(int ip, int i) {
        LatticeNode cell = this.nodes[ip][i];
        if (cell == null) {
            return true;
        }
        return Double.isNaN(cell.alpha);
    }

    /**
     * Normalises the alphas of one lattice position so that they sum to one, and records the
     * log of the normalizer, added to the previous position's accumulated value, in
     * {@code alphaLogScaling[ip]}.
     *
     * @param ip input position whose valid cells are rescaled
     */
    private void rescaleAlphas(int ip) {
        double total = 0.0;
        for (int s = 0; s < this.t.numStates(); ++s) {
            if (this.isInvalidNode(ip, s)) {
                continue;
            }
            total += this.nodes[ip][s].alpha;
        }
        assert (total > 0.0) : "Invalid sum over alphas for ip=" + ip;

        // Position 0 has no predecessor, so its accumulated scale starts from zero.
        double carried = ip == 0 ? 0.0 : this.alphaLogScaling[ip - 1];
        this.alphaLogScaling[ip] = Math.log(total) + carried;

        for (int s = 0; s < this.t.numStates(); ++s) {
            if (this.isInvalidNode(ip, s)) {
                continue;
            }
            this.nodes[ip][s].alpha /= total;
        }
    }

    /**
     * Normalises the betas of one lattice position so that they sum to one, and records the
     * log of the normalizer, added to the next position's accumulated value, in
     * {@code betaLogScaling[ip]}.
     *
     * @param ip input position whose valid cells are rescaled
     */
    private void rescaleBetas(int ip) {
        double total = 0.0;
        for (int s = 0; s < this.t.numStates(); ++s) {
            if (this.isInvalidNode(ip, s)) {
                continue;
            }
            total += this.nodes[ip][s].beta;
        }
        assert (total > 0.0) : "Invalid sum over betas for ip=" + ip;

        // The last position has no successor, so its accumulated scale starts from zero.
        double carried = ip == this.latticeLength - 1 ? 0.0 : this.betaLogScaling[ip + 1];
        this.betaLogScaling[ip] = Math.log(total) + carried;

        for (int s = 0; s < this.t.numStates(); ++s) {
            if (this.isInvalidNode(ip, s)) {
                continue;
            }
            this.nodes[ip][s].beta /= total;
        }
    }

    /**
     * Returns the xi table itself (not a copy).
     *
     * @return the {@code [position][from][to]} table, or {@code null} when xis were not saved
     */
    @Override
    public double[][][] getXis() {
        return xis;
    }

    /**
     * Returns the gamma table itself (not a copy).
     *
     * @return the {@code [position][state]} table
     */
    @Override
    public double[][] getGammas() {
        return gammas;
    }

    /**
     * Returns the total weight stored by the constructor.
     *
     * @return the stored {@code totalWeight} field
     */
    @Override
    public double getTotalWeight() {
        return totalWeight;
    }

    /**
     * Looks up the stored gamma entry for a state at a position.
     *
     * @param inputPosition lattice position
     * @param s state whose index selects the column
     * @return the entry of the gamma table at that position and state
     */
    @Override
    public double getGammaWeight(int inputPosition, Transducer.State s) {
        double[] row = gammas[inputPosition];
        return row[s.getIndex()];
    }

    /**
     * Looks up the stored gamma entry by state index.
     *
     * @param inputPosition lattice position
     * @param stateIndex column of the gamma table
     * @return the entry of the gamma table at that position and state index
     */
    public double getGammaWeight(int inputPosition, int stateIndex) {
        double[] row = gammas[inputPosition];
        return row[stateIndex];
    }

    /**
     * Returns the exponential of the stored gamma entry for a state at a position.
     *
     * @param inputPosition lattice position
     * @param s state whose index selects the column
     * @return {@code Math.exp} of the gamma table entry
     */
    @Override
    public double getGammaProbability(int inputPosition, Transducer.State s) {
        double logGamma = gammas[inputPosition][s.getIndex()];
        return Math.exp(logGamma);
    }

    /**
     * Resolves the state for {@code stateIndex} through the transducer and delegates to
     * {@link #getGammaProbability(int, Transducer.State)}.
     *
     * @param inputPosition lattice position
     * @param stateIndex index passed to the transducer's {@code getState}
     * @return the value returned by the state-based overload
     */
    public double getGammaProbability(int inputPosition, int stateIndex) {
        Transducer.State state = t.getState(stateIndex);
        return getGammaProbability(inputPosition, state);
    }

    /**
     * Returns the exponential of {@link #getXiWeight(int, Transducer.State, Transducer.State)}.
     *
     * @param ip lattice position of the source state
     * @param s1 source state
     * @param s2 destination state
     * @return {@code Math.exp} of the xi weight
     * @throws IllegalStateException propagated from {@code getXiWeight} when xis were not saved
     */
    @Override
    public double getXiProbability(int ip, Transducer.State s1, Transducer.State s2) {
        double logXi = getXiWeight(ip, s1, s2);
        return Math.exp(logXi);
    }

    /**
     * Looks up the stored xi entry for a transition between two states at a position.
     *
     * @param ip lattice position of the source state
     * @param s1 source state
     * @param s2 destination state
     * @return the xi table entry at {@code [ip][s1.getIndex()][s2.getIndex()]}
     * @throws IllegalStateException if the xi table is {@code null}
     */
    @Override
    public double getXiWeight(int ip, Transducer.State s1, Transducer.State s2) {
        if (xis != null) {
            int from = s1.getIndex();
            int to = s2.getIndex();
            return xis[ip][from][to];
        }
        throw new IllegalStateException("xis were not saved.");
    }

    /**
     * Returns the number of lattice positions.
     *
     * @return the stored {@code latticeLength} field
     */
    @Override
    public int length() {
        return latticeLength;
    }

    /**
     * Returns the cell's alpha multiplied by the exponential of the accumulated alpha log scaling
     * for that position. The cell is obtained through {@code getLatticeNode}, so a missing cell
     * is created by this call.
     *
     * @param ip lattice position
     * @param s state whose index selects the cell
     * @return {@code alpha * exp(alphaLogScaling[ip])}
     */
    @Override
    public double getAlpha(int ip, Transducer.State s) {
        int stateIndex = s.getIndex();
        double rescaled = getLatticeNode(ip, stateIndex).alpha;
        double scale = Math.exp(alphaLogScaling[ip]);
        return rescaled * scale;
    }

    /**
     * Returns the cell's beta multiplied by the exponential of the accumulated beta log scaling
     * for that position. The cell is obtained through {@code getLatticeNode}, so a missing cell
     * is created by this call.
     *
     * @param ip lattice position
     * @param s state whose index selects the cell
     * @return {@code beta * exp(betaLogScaling[ip])}
     */
    @Override
    public double getBeta(int ip, Transducer.State s) {
        int stateIndex = s.getIndex();
        double rescaled = getLatticeNode(ip, stateIndex).beta;
        double scale = Math.exp(betaLogScaling[ip]);
        return rescaled * scale;
    }

    /**
     * Not supported by this lattice implementation.
     *
     * @param outputPosition ignored
     * @return never returns normally
     * @throws RuntimeException always
     */
    @Override
    public LabelVector getLabelingAtPosition(int outputPosition) {
        String message = "Not implemented for SumLatticeScaling!";
        throw new RuntimeException(message);
    }

    /**
     * Returns the input sequence this lattice was built over.
     *
     * @return the stored {@code input} field
     */
    @Override
    public Sequence getInput() {
        return input;
    }

    /**
     * Returns the transducer this lattice was built from.
     *
     * @return the stored {@code t} field
     */
    @Override
    public Transducer getTransducer() {
        return t;
    }

    /**
     * Serializable factory that produces {@link SumLatticeScaling} lattices. Its serialized form
     * consists of a single version integer.
     */
    public static class Factory
    extends SumLatticeFactory
    implements Serializable {
        private static final long serialVersionUID = 1L;
        /** Version number written to and read back from the serialized stream. */
        private static final int CURRENT_SERIAL_VERSION = 1;

        /**
         * Creates a new scaled lattice, forwarding every argument to the six-argument
         * {@link SumLatticeScaling} constructor.
         */
        @Override
        public SumLattice newSumLattice(Transducer trans, Sequence input, Sequence output, Transducer.Incrementor incrementor, boolean saveXis, LabelAlphabet outputAlphabet) {
            SumLattice lattice = new SumLatticeScaling(trans, input, output, incrementor, saveXis, outputAlphabet);
            return lattice;
        }

        /** Writes only the serial version number. */
        private void writeObject(ObjectOutputStream out) throws IOException {
            out.writeInt(CURRENT_SERIAL_VERSION);
        }

        /** Consumes the serial version number; the value is read but not otherwise used. */
        private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
            in.readInt();
        }
    }

    /**
     * One lattice cell: a transducer state at an input position, together with its forward
     * (alpha) and backward (beta) values. Both values are NaN until something assigns them.
     */
    protected class LatticeNode {
        /** Input position this cell belongs to. */
        int inputPosition;
        /** Transducer state this cell belongs to. */
        Transducer.State state;
        /** Output object recorded for this cell by whoever fills the lattice. */
        Object output;
        /** Forward value; NaN marks "not set". */
        double alpha;
        /** Backward value; NaN marks "not set". */
        double beta;

        LatticeNode(int inputPosition, Transducer.State state) {
            this.alpha = Double.NaN;
            this.beta = Double.NaN;
            this.inputPosition = inputPosition;
            this.state = state;
        }
    }
}

