/*
 * Decompiled with CFR 0.152.
 */
package edu.umass.cs.mallet.base.types;

import edu.umass.cs.mallet.base.types.Alphabet;
import edu.umass.cs.mallet.base.types.Dirichlet;
import edu.umass.cs.mallet.base.types.FeatureSequence;
import edu.umass.cs.mallet.base.types.FeatureVector;
import edu.umass.cs.mallet.base.util.Random;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;

/**
 * A multinomial (categorical) distribution over a finite set of outcomes, stored as a
 * dense array of per-outcome values in {@code values}. Outcomes may optionally be named
 * by an {@link Alphabet} ({@code dictionary}).
 *
 * <p>Nested classes provide a log-space variant ({@link Logged}) and several estimators
 * that build a {@code Multinomial} from accumulated counts.
 */
public class Multinomial
extends FeatureVector {
    private static final long serialVersionUID = 1L;

    /** Maximum allowed deviation of the probabilities' sum from 1.0 when checking. */
    private static final double SUM_TOLERANCE = 1.0E-4;

    /**
     * Builds (or adopts) the backing array for a new Multinomial.
     *
     * @param probabilities source probabilities
     * @param dictionary    optional alphabet; when non-null and {@code copy} is true, the
     *                      result is sized to {@code dictionary.size()}
     * @param size          number of leading entries of {@code probabilities} to copy
     * @param copy          if true, copy into a fresh array; otherwise use
     *                      {@code probabilities} directly
     * @param checkSum      if true, require the entries to sum to 1 (within SUM_TOLERANCE)
     * @return the array to use as {@code values}
     * @throws IllegalArgumentException if {@code checkSum} is set and the sum is not ~1
     *                                  (a NaN sum is also rejected)
     */
    private static double[] getValues(double[] probabilities, Alphabet dictionary, int size, boolean copy, boolean checkSum) {
        assert dictionary == null || dictionary.size() >= size;
        double[] values;
        if (copy) {
            // Entries beyond 'size' (when the alphabet is larger) stay at the default 0.0.
            values = new double[dictionary == null ? size : dictionary.size()];
            System.arraycopy(probabilities, 0, values, 0, size);
        } else {
            assert dictionary == null || dictionary.size() == probabilities.length;
            values = probabilities;
        }
        if (checkSum) {
            double sum = 0.0;
            for (int i = 0; i < values.length; i++) {
                sum += values[i];
            }
            // Written as !(x <= tol) so that a NaN sum is rejected as well.
            if (!(Math.abs(sum - 1.0) <= SUM_TOLERANCE)) {
                throw new IllegalArgumentException("Probabilities sum to " + sum + ", not to one.");
            }
        }
        return values;
    }

    /**
     * Core constructor used by the public constructors and by the estimators.
     * See {@code getValues} for the meaning of the arguments.
     */
    protected Multinomial(double[] probabilities, Alphabet dictionary, int size, boolean copy, boolean checkSum) {
        super(dictionary, Multinomial.getValues(probabilities, dictionary, size, copy, checkSum));
    }

    /** Copies {@code dictionary.size()} probabilities and checks that they sum to one. */
    public Multinomial(double[] probabilities, Alphabet dictionary) {
        this(probabilities, dictionary, dictionary.size(), true, true);
    }

    /** Copies the first {@code size} probabilities (no dictionary) and checks the sum. */
    public Multinomial(double[] probabilities, int size) {
        this(probabilities, null, size, true, true);
    }

    /** Copies all of {@code probabilities} (no dictionary) and checks the sum. */
    public Multinomial(double[] probabilities) {
        this(probabilities, null, probabilities.length, true, true);
    }

    /** @return the number of outcomes (length of the backing array) */
    public int size() {
        return this.values.length;
    }

    /** @return the probability of outcome {@code featureIndex} */
    public double probability(int featureIndex) {
        return this.values[featureIndex];
    }

    /**
     * Returns the probability of the outcome named {@code key}.
     * NOTE(review): the index comes from {@code dictionary.lookupIndex(key)}; if that call
     * adds unseen keys, an unseen key yields an index past the end of {@code values}.
     *
     * @throws IllegalStateException if this Multinomial has no dictionary
     */
    public double probability(Object key) {
        if (this.dictionary == null) {
            throw new IllegalStateException("This Multinomial has no dictionary.");
        }
        return this.probability(this.dictionary.lookupIndex(key));
    }

    /** @return the natural log of the probability of outcome {@code featureIndex} */
    public double logProbability(int featureIndex) {
        return Math.log(this.values[featureIndex]);
    }

    /**
     * Returns the log probability of the outcome named {@code key}.
     *
     * @throws IllegalStateException if this Multinomial has no dictionary
     */
    public double logProbability(Object key) {
        if (this.dictionary == null) {
            throw new IllegalStateException("This Multinomial has no dictionary.");
        }
        return this.logProbability(this.dictionary.lookupIndex(key));
    }

    /** @return the dictionary naming the outcomes, or null if there is none */
    public Alphabet getAlphabet() {
        return this.dictionary;
    }

    /**
     * Adds each outcome's probability into the corresponding slot of {@code vector}.
     * Uses {@link #probability(int)} so that subclasses storing values in another space
     * (e.g. {@link Logged}) contribute real probabilities.
     */
    public void addProbabilitiesTo(double[] vector) {
        for (int i = 0; i < this.values.length; i++) {
            vector[i] += this.probability(i);
        }
    }

    /**
     * Draws an outcome index by inverse-CDF sampling with one uniform draw from {@code r}.
     *
     * @return the first index whose cumulative probability reaches the draw; if floating
     *         rounding leaves the total just short of the draw, the last index with
     *         positive probability
     * @throws IllegalStateException if no outcome has positive probability
     */
    public int randomIndex(Random r) {
        double f = r.nextUniform();
        double sum = 0.0;
        int lastPositive = -1;
        for (int i = 0; i < this.values.length; i++) {
            // probability(i), not values[i]: Logged stores log-probabilities.
            double p = this.probability(i);
            if (p > 0.0) {
                lastPositive = i;
            }
            sum += p;
            if (sum >= f) {
                return i;
            }
        }
        if (lastPositive < 0) {
            throw new IllegalStateException("Multinomial has no positive probability mass.");
        }
        return lastPositive;
    }

    /**
     * Draws an outcome and returns its name from the dictionary.
     *
     * @throws IllegalStateException if this Multinomial has no dictionary
     */
    public Object randomObject(Random r) {
        if (this.dictionary == null) {
            throw new IllegalStateException("This Multinomial has no dictionary.");
        }
        return this.dictionary.lookupObject(this.randomIndex(r));
    }

    /**
     * Builds a FeatureSequence of {@code length} independent draws.
     *
     * @throws UnsupportedOperationException if the dictionary is null (or not an Alphabet)
     */
    public FeatureSequence randomFeatureSequence(Random r, int length) {
        if (!(this.dictionary instanceof Alphabet)) {
            throw new UnsupportedOperationException("Multinomial's dictionary must be a Alphabet");
        }
        FeatureSequence fs = new FeatureSequence(this.dictionary, length);
        for (int remaining = length; remaining > 0; remaining--) {
            fs.add(this.randomIndex(r));
        }
        return fs;
    }

    /** Builds a FeatureVector from {@code size} independent draws. */
    public FeatureVector randomFeatureVector(Random r, int size) {
        return new FeatureVector(this.randomFeatureSequence(r, size));
    }

    /**
     * A Multinomial whose {@code values} hold natural-log probabilities.
     * {@link #probability(int)} exponentiates; {@link #logProbability(int)} reads directly.
     */
    public static class Logged
    extends Multinomial {
        private static final long serialVersionUID = 1L;

        /**
         * @param probabilities    probabilities, or log-probabilities if
         *                         {@code areLoggedAlready}
         * @param dictionary       optional alphabet; when non-null its size must equal
         *                         {@code size}
         * @param size             number of outcomes
         * @param areLoggedAlready if false, the (copied) values are checked to sum to one
         *                         and then converted to logs
         */
        public Logged(double[] probabilities, Alphabet dictionary, int size, boolean areLoggedAlready) {
            super(probabilities, dictionary, size, true, !areLoggedAlready);
            assert dictionary == null || dictionary.size() == size;
            if (!areLoggedAlready) {
                for (int i = 0; i < size; i++) {
                    this.values[i] = Math.log(this.values[i]);
                }
            }
        }

        /** Size is taken from the dictionary, or from the array when there is none. */
        public Logged(double[] probabilities, Alphabet dictionary, boolean areLoggedAlready) {
            this(probabilities, dictionary, dictionary == null ? probabilities.length : dictionary.size(), areLoggedAlready);
        }

        /** Takes plain probabilities and converts them to logs. */
        public Logged(double[] probabilities, Alphabet dictionary, int size) {
            this(probabilities, dictionary, size, false);
        }

        /** Takes plain probabilities and converts them to logs. */
        public Logged(double[] probabilities, Alphabet dictionary) {
            this(probabilities, dictionary, dictionary.size(), false);
        }

        /**
         * Log-space copy of {@code m}. If {@code m} is itself a Logged, its values are
         * already logs and are copied as-is rather than logged a second time.
         */
        public Logged(Multinomial m) {
            this(m.values, m.dictionary, m instanceof Logged);
        }

        /** Takes plain probabilities (no dictionary) and converts them to logs. */
        public Logged(double[] probabilities) {
            this(probabilities, null, false);
        }

        public double probability(int featureIndex) {
            return Math.exp(this.values[featureIndex]);
        }

        public double logProbability(int featureIndex) {
            return this.values[featureIndex];
        }

        /** Not supported; {@link #addProbabilitiesTo(double[])} adds real probabilities. */
        public void addProbabilities(double[] vector) {
            throw new UnsupportedOperationException("Not implemented.");
        }

        /**
         * Adds each log-probability into {@code vector}, and sets every slot of
         * {@code vector} past this distribution's size to negative infinity.
         */
        public void addLogProbabilities(double[] vector) {
            for (int i = 0; i < this.values.length; i++) {
                vector[i] += this.values[i];
            }
            for (int i = this.values.length; i < vector.length; i++) {
                vector[i] = Double.NEGATIVE_INFINITY;
            }
        }
    }

    /**
     * Accumulates per-outcome counts and turns them into a Multinomial via
     * {@link #estimate()}. The counts array grows on demand (doubling, minimum
     * {@code minCapacity}).
     */
    public static abstract class Estimator
    implements Cloneable,
    Serializable {
        Alphabet dictionary;
        double[] counts;
        /** Number of outcomes seen so far; only used directly when there is no dictionary. */
        int size;
        static final int minCapacity = 16;
        private static final long serialVersionUID = 1L;
        private static final int CURRENT_SERIAL_VERSION = 1;

        protected Estimator(double[] counts, int size, Alphabet dictionary) {
            this.counts = counts;
            this.size = size;
            this.dictionary = dictionary;
        }

        /** Starts from the given counts, sized to the dictionary. */
        public Estimator(double[] counts, Alphabet dictionary) {
            this(counts, dictionary.size(), dictionary);
        }

        /** Empty estimator with no dictionary. */
        public Estimator() {
            this(new double[minCapacity], 0, null);
        }

        /** Estimator with no dictionary and {@code size} outcomes. */
        public Estimator(int size) {
            this(new double[Math.max(size, minCapacity)], size, null);
        }

        /** Estimator over the outcomes of {@code dictionary}. */
        public Estimator(Alphabet dictionary) {
            this(new double[dictionary.size()], dictionary.size(), dictionary);
        }

        /** Switches to dictionary {@code d} and discards all accumulated counts. */
        public void setAlphabet(Alphabet d) {
            this.size = d.size();
            this.counts = new double[this.size];
            this.dictionary = d;
        }

        /** @return the dictionary size if there is one, otherwise the number of outcomes seen */
        public int size() {
            return this.dictionary == null ? this.size : this.dictionary.size();
        }

        /**
         * Makes {@code index} a valid position: grows {@code counts} if needed (doubling
         * from at least {@code minCapacity}) and raises {@code size} to {@code index + 1}.
         */
        protected void ensureCapacity(int index) {
            if (index + 1 > this.size) {
                this.size = index + 1;
            }
            if (this.counts.length <= index) {
                int newLength = Math.max(this.counts.length, minCapacity);
                while (newLength <= index) {
                    newLength *= 2;
                }
                double[] newCounts = new double[newLength];
                System.arraycopy(this.counts, 0, newCounts, 0, this.counts.length);
                this.counts = newCounts;
            }
        }

        /** Zeroes every count (the recorded size is left unchanged). */
        public void reset() {
            for (int i = 0; i < this.counts.length; i++) {
                this.counts[i] = 0.0;
            }
        }

        private void setCounts(double[] counts) {
            assert this.dictionary == null || counts.length <= this.size();
            this.counts = counts;
        }

        /** Adds {@code count} to outcome {@code index}, growing storage as needed. */
        public void increment(int index, double count) {
            // ensureCapacity also raises size to index + 1.
            this.ensureCapacity(index);
            this.counts[index] += count;
        }

        /** Adds {@code count} to the outcome named {@code key} (requires a dictionary). */
        public void increment(String key, double count) {
            this.increment(this.dictionary.lookupIndex(key), count);
        }

        /**
         * Adds {@code scale} once per position of {@code fs}.
         *
         * @throws IllegalArgumentException if {@code fs} uses a different alphabet
         */
        public void increment(FeatureSequence fs, double scale) {
            if (fs.getAlphabet() != this.dictionary) {
                throw new IllegalArgumentException("Vocabularies don't match.");
            }
            for (int pos = 0; pos < fs.size(); pos++) {
                this.increment(fs.getIndexAtPosition(pos), scale);
            }
        }

        public void increment(FeatureSequence fs) {
            this.increment(fs, 1.0);
        }

        /**
         * Adds {@code scale} once per location of {@code fv}; the vector's own values are
         * not used.
         *
         * @throws IllegalArgumentException if {@code fv} uses a different alphabet
         */
        public void increment(FeatureVector fv, double scale) {
            if (fv.getAlphabet() != this.dictionary) {
                throw new IllegalArgumentException("Vocabularies don't match.");
            }
            for (int loc = 0; loc < fv.numLocations(); loc++) {
                this.increment(fv.indexAtLocation(loc), scale);
            }
        }

        public void increment(FeatureVector fv) {
            this.increment(fv, 1.0);
        }

        public double getCount(int index) {
            return this.counts[index];
        }

        /**
         * Returns a copy whose counts array is independent of this one, so incrementing
         * the clone does not change the original. The dictionary is shared.
         */
        public Object clone() {
            try {
                Estimator copy = (Estimator) super.clone();
                if (this.counts != null) {
                    copy.counts = (double[]) this.counts.clone();
                }
                return copy;
            }
            catch (CloneNotSupportedException e) {
                // Unreachable: Estimator implements Cloneable.
                throw new InternalError(e.toString());
            }
        }

        /** Prints the first {@code size} counts to standard output. */
        public void print() {
            System.out.println("Multinomial.Estimator");
            for (int i = 0; i < this.size; i++) {
                System.out.println("counts[" + i + "] = " + this.counts[i]);
            }
        }

        /** Builds a Multinomial from the accumulated counts. */
        public abstract Multinomial estimate();

        private void writeObject(ObjectOutputStream out) throws IOException {
            out.writeInt(CURRENT_SERIAL_VERSION);
            out.writeObject(this.dictionary);
            out.writeObject(this.counts);
            out.writeInt(this.size);
        }

        private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
            int version = in.readInt();
            if (version != CURRENT_SERIAL_VERSION) {
                throw new ClassNotFoundException("Mismatched Multinomial.Estimator versions: wanted " + CURRENT_SERIAL_VERSION + ", got " + version);
            }
            this.dictionary = (Alphabet) in.readObject();
            this.counts = (double[]) in.readObject();
            this.size = in.readInt();
        }
    }

    /** Estimator that adds a pseudo-count {@code m} to every outcome before normalizing. */
    public static class MEstimator
    extends Estimator {
        double m;
        private static final long serialVersionUID = 1L;
        private static final int CURRENT_SERIAL_VERSION = 1;

        public MEstimator(Alphabet dictionary, double m) {
            super(dictionary);
            this.m = m;
        }

        public MEstimator(int size, double m) {
            super(size);
            this.m = m;
        }

        public MEstimator(double m) {
            this.m = m;
        }

        /**
         * Returns the distribution p[i] = (counts[i] + m) / sum_j (counts[j] + m). If that
         * sum is zero (e.g. m == 0 and no counts), every entry is NaN.
         */
        public Multinomial estimate() {
            double[] pr = new double[this.dictionary == null ? this.size : this.dictionary.size()];
            if (this.dictionary != null) {
                // The dictionary may have grown since the last increment; cover all of it.
                this.ensureCapacity(this.dictionary.size() - 1);
            }
            double sum = 0.0;
            for (int i = 0; i < pr.length; i++) {
                pr[i] = this.counts[i] + this.m;
                sum += pr[i];
            }
            for (int i = 0; i < pr.length; i++) {
                pr[i] /= sum;
            }
            // 'pr' is adopted without copying; its length is the size that getValues checks.
            return new Multinomial(pr, this.dictionary, pr.length, false, false);
        }

        private void writeObject(ObjectOutputStream out) throws IOException {
            out.writeInt(CURRENT_SERIAL_VERSION);
            out.writeDouble(this.m);
        }

        private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
            int version = in.readInt();
            if (version != CURRENT_SERIAL_VERSION) {
                throw new ClassNotFoundException("Mismatched Multinomial.MEstimator versions: wanted " + CURRENT_SERIAL_VERSION + ", got " + version);
            }
            this.m = in.readDouble();
        }
    }

    /** Maximum-likelihood estimator: an MEstimator with pseudo-count 0. */
    public static class MLEstimator
    extends MEstimator {
        private static final long serialVersionUID = 1L;
        private static final int CURRENT_SERIAL_VERSION = 1;

        public MLEstimator() {
            super(0.0);
        }

        public MLEstimator(int size) {
            super(size, 0.0);
        }

        public MLEstimator(Alphabet dictionary) {
            super(dictionary, 0.0);
        }

        private void writeObject(ObjectOutputStream out) throws IOException {
            out.writeInt(CURRENT_SERIAL_VERSION);
        }

        private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
            int version = in.readInt();
            if (version != CURRENT_SERIAL_VERSION) {
                throw new ClassNotFoundException("Mismatched Multinomial.MLEstimator versions: wanted " + CURRENT_SERIAL_VERSION + ", got " + version);
            }
        }
    }

    /** Laplace (add-one) estimator: an MEstimator with pseudo-count 1. */
    public static class LaplaceEstimator
    extends MEstimator {
        private static final long serialVersionUID = 1L;
        private static final int CURRENT_SERIAL_VERSION = 1;

        public LaplaceEstimator() {
            super(1.0);
        }

        public LaplaceEstimator(int size) {
            super(size, 1.0);
        }

        public LaplaceEstimator(Alphabet dictionary) {
            super(dictionary, 1.0);
        }

        private void writeObject(ObjectOutputStream out) throws IOException {
            out.writeInt(CURRENT_SERIAL_VERSION);
        }

        private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
            int version = in.readInt();
            if (version != CURRENT_SERIAL_VERSION) {
                throw new ClassNotFoundException("Mismatched Multinomial.LaplaceEstimator versions: wanted " + CURRENT_SERIAL_VERSION + ", got " + version);
            }
        }
    }

    /** Estimator holding a Dirichlet prior. */
    public static class MAPEstimator
    extends Estimator {
        Dirichlet prior;
        private static final long serialVersionUID = 1L;
        private static final int CURRENT_SERIAL_VERSION = 1;

        public MAPEstimator(Dirichlet d) {
            super(d.size());
            this.prior = d;
        }

        /** TODO: not implemented; currently always returns null. */
        public Multinomial estimate() {
            return null;
        }

        private void writeObject(ObjectOutputStream out) throws IOException {
            out.writeInt(CURRENT_SERIAL_VERSION);
        }

        private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
            int version = in.readInt();
            if (version != CURRENT_SERIAL_VERSION) {
                throw new ClassNotFoundException("Mismatched Multinomial.MAPEstimator versions: wanted " + CURRENT_SERIAL_VERSION + ", got " + version);
            }
        }
    }
}

