/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.fst.semi_supervised.pr.constraints;

import cc.mallet.fst.semi_supervised.StateLabelMap;
import cc.mallet.fst.semi_supervised.pr.constraints.PRConstraint;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.FeatureVectorSequence;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import gnu.trove.TIntArrayList;
import gnu.trove.TIntObjectHashMap;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.HashMap;

/**
 * A set of posterior-regularization constraints with a squared (L2) penalty,
 * organised per input feature.
 *
 * <p>Each constrained feature index owns a {@link OneLabelL2IndPRConstraint}
 * holding one entry per constrained label: a target value, a weight and the
 * index of the entry's slot in the shared parameter / expectation vectors.
 * Slots are handed out sequentially by {@link #addConstraint}.
 *
 * <p>When {@code normalized} is true, scores and expectations are divided by
 * the number of times the feature was seen in {@link #preProcess(InstanceList)}.
 */
public class OneLabelL2IndPRConstraints
implements PRConstraint {
    /** Whether scores and expectations are divided by the per-feature count. */
    protected boolean normalized;
    /** Number of (feature, label) entries; also the next free parameter index. */
    protected int numDimensions;
    /** Per-feature constraints, keyed by feature index. */
    protected TIntObjectHashMap<OneLabelL2IndPRConstraint> constraints;
    /** Converts state indices to label indices; {@code null} until set. */
    protected StateLabelMap map;
    /** Constrained feature indices found by the last {@link #preProcess(FeatureVector)}. */
    protected TIntArrayList cache;

    /**
     * Creates an empty constraint set.
     *
     * @param normalized whether scores and expectations are divided by feature counts
     */
    public OneLabelL2IndPRConstraints(boolean normalized) {
        this.normalized = normalized;
        this.numDimensions = 0;
        this.constraints = new TIntObjectHashMap<OneLabelL2IndPRConstraint>();
        this.map = null;
        this.cache = new TIntArrayList();
    }

    /**
     * Copy constructor used by {@link #copy()}: copies every per-feature
     * constraint and recomputes the dimension count from the copies' sources.
     *
     * @param constraints constraints to copy, keyed by feature index
     * @param map state-to-label map, stored by reference
     * @param normalized whether scores and expectations are divided by feature counts
     */
    protected OneLabelL2IndPRConstraints(TIntObjectHashMap<OneLabelL2IndPRConstraint> constraints, StateLabelMap map, boolean normalized) {
        this.normalized = normalized;
        this.numDimensions = 0;
        this.constraints = new TIntObjectHashMap<OneLabelL2IndPRConstraint>();
        for (int key : constraints.keys()) {
            OneLabelL2IndPRConstraint source = constraints.get(key);
            this.constraints.put(key, source.copy());
            this.numDimensions += source.getNumConstrainedLabels();
        }
        this.map = map;
        this.cache = new TIntArrayList();
    }

    /** Returns a new constraint set with copied per-feature constraints and a fresh cache. */
    @Override
    public PRConstraint copy() {
        return new OneLabelL2IndPRConstraints(this.constraints, this.map, this.normalized);
    }

    /**
     * Adds a constraint entry for a (feature, label) pair and assigns it the
     * next free parameter index.
     *
     * @param fi feature index
     * @param li label index
     * @param target target value for the entry
     * @param weight weight of the squared penalty for the entry
     */
    public void addConstraint(int fi, int li, double target, double weight) {
        if (!this.constraints.containsKey(fi)) {
            this.constraints.put(fi, new OneLabelL2IndPRConstraint());
        }
        this.constraints.get(fi).add(li, target, weight, this.numDimensions);
        ++this.numDimensions;
    }

    /** Returns the number of constraint entries added so far. */
    @Override
    public int numDimensions() {
        return this.numDimensions;
    }

    /** Sets the map used to translate state indices into label indices. */
    @Override
    public void setStateLabelMap(StateLabelMap map) {
        this.map = map;
    }

    /**
     * Rebuilds the cache with the constrained feature indices present in
     * {@code fv}; {@link #getScore} and {@link #incrementExpectations} then
     * operate on that cache.
     */
    @Override
    public void preProcess(FeatureVector fv) {
        this.cache.resetQuick();
        for (int loc = 0; loc < fv.numLocations(); loc++) {
            int fi = fv.indexAtLocation(loc);
            if (this.constraints.containsKey(fi)) {
                this.cache.add(fi);
            }
        }
    }

    /**
     * Counts occurrences of each constrained feature across {@code data}.
     *
     * @param data instances whose data is a {@link FeatureVectorSequence}
     * @return a bit set with bit {@code i} set when instance {@code i} contains
     *         at least one constrained feature
     */
    @Override
    public BitSet preProcess(InstanceList data) {
        BitSet bitSet = new BitSet(data.size());
        int ii = 0;
        for (Instance instance : data) {
            FeatureVectorSequence fvs = (FeatureVectorSequence)instance.getData();
            for (int ip = 0; ip < fvs.size(); ip++) {
                FeatureVector fv = fvs.get(ip);
                for (int loc = 0; loc < fv.numLocations(); loc++) {
                    int fi = fv.indexAtLocation(loc);
                    if (this.constraints.containsKey(fi)) {
                        this.constraints.get(fi).count += 1.0;
                        bitSet.set(ii);
                    }
                }
            }
            ++ii;
        }
        return bitSet;
    }

    /**
     * Sums the scores of all cached constraints for the label of
     * {@code destIndex}. Only {@code destIndex} and {@code parameters} are
     * read; the cache must have been filled by {@link #preProcess(FeatureVector)}.
     */
    @Override
    public double getScore(FeatureVector input, int inputPosition, int srcIndex, int destIndex, double[] parameters) {
        int label = this.map.getLabelIndex(destIndex);
        double dot = 0.0;
        for (int i = 0; i < this.cache.size(); i++) {
            OneLabelL2IndPRConstraint constraint = this.constraints.get(this.cache.getQuick(i));
            dot += constraint.getScore(label, parameters);
        }
        return dot;
    }

    /**
     * Adds {@code prob} to the expectation of every cached constraint that
     * constrains the label of {@code destIndex}.
     */
    @Override
    public void incrementExpectations(FeatureVector input, int inputPosition, int srcIndex, int destIndex, double prob) {
        int label = this.map.getLabelIndex(destIndex);
        for (int i = 0; i < this.cache.size(); i++) {
            this.constraints.get(this.cache.getQuick(i)).incrementExpectation(label, prob);
        }
    }

    /** Writes every constraint's expectations into {@code expectations} at its parameter indices. */
    @Override
    public void getExpectations(double[] expectations) {
        assert (expectations.length == this.numDimensions()) : expectations.length + " " + this.numDimensions();
        for (int fi : this.constraints.keys()) {
            this.constraints.get(fi).getExpectations(expectations);
        }
    }

    /** Adds the values in {@code expectations} to each constraint's stored expectations. */
    @Override
    public void addExpectations(double[] expectations) {
        assert (expectations.length == this.numDimensions());
        for (int fi : this.constraints.keys()) {
            this.constraints.get(fi).addExpectations(expectations);
        }
    }

    /** Resets every constraint's expectations to zero. */
    @Override
    public void zeroExpectations() {
        for (int fi : this.constraints.keys()) {
            this.constraints.get(fi).zeroExpectation();
        }
    }

    /** Returns the sum over all entries of {@code target * p - p * p / (2 * weight)}. */
    @Override
    public double getAuxiliaryValueContribution(double[] parameters) {
        double value = 0.0;
        for (int fi : this.constraints.keys()) {
            value += this.constraints.get(fi).getProjectionValueContrib(parameters);
        }
        return value;
    }

    /**
     * Returns the sum over all entries of
     * {@code weight * (target - expectation)^2 / 2}. {@code parameters} is not read.
     */
    @Override
    public double getCompleteValueContribution(double[] parameters) {
        double value = 0.0;
        for (int fi : this.constraints.keys()) {
            value += this.constraints.get(fi).getCompleteValueContrib();
        }
        return value;
    }

    /** Accumulates every constraint's gradient contribution into {@code gradient}. */
    @Override
    public void getGradient(double[] parameters, double[] gradient) {
        for (int fi : this.constraints.keys()) {
            this.constraints.get(fi).getGradient(parameters, gradient);
        }
    }

    /**
     * The constraint entries attached to a single feature. Parallel lists hold,
     * per entry, the label, its parameter index, its target and its weight;
     * {@code labelMap} maps a label to its position in those lists.
     */
    protected class OneLabelL2IndPRConstraint {
        /** Number of entries added; also the position the next entry receives. */
        protected int index = 0;
        /** Occurrences of the feature counted by {@code preProcess(InstanceList)}. */
        protected double count = 0.0;
        protected ArrayList<Integer> labels = new ArrayList<Integer>();
        protected ArrayList<Integer> paramIndices = new ArrayList<Integer>();
        protected ArrayList<Double> targets = new ArrayList<Double>();
        protected ArrayList<Double> weights = new ArrayList<Double>();
        protected HashMap<Integer, Integer> labelMap = new HashMap<Integer, Integer>();
        /** Per-entry expectations; {@code null} until zeroed or created by {@link #copy()}. */
        protected double[] expectation;

        /**
         * Returns a copy with the same entries and count and zeroed
         * expectations. The collections are duplicated so that entries added
         * to one instance later cannot grow the other's lists beyond the size
         * of its expectation array.
         *
         * <p>NOTE(review): the copy is created as an inner instance of the
         * same enclosing object as this constraint, so it reads
         * {@code normalized} from that object.
         */
        public OneLabelL2IndPRConstraint copy() {
            OneLabelL2IndPRConstraint copy = new OneLabelL2IndPRConstraint();
            copy.index = this.index;
            copy.count = this.count;
            copy.labels = new ArrayList<Integer>(this.labels);
            copy.paramIndices = new ArrayList<Integer>(this.paramIndices);
            copy.targets = new ArrayList<Double>(this.targets);
            copy.weights = new ArrayList<Double>(this.weights);
            copy.labelMap = new HashMap<Integer, Integer>(this.labelMap);
            copy.expectation = new double[this.index];
            return copy;
        }

        /**
         * Appends an entry for {@code label}.
         *
         * @param label label index
         * @param target target value
         * @param weight penalty weight
         * @param paramIndex slot of this entry in the parameter / expectation vectors
         */
        public void add(int label, double target, double weight, int paramIndex) {
            this.targets.add(target);
            this.weights.add(weight);
            this.labels.add(label);
            this.paramIndices.add(paramIndex);
            this.labelMap.put(label, this.index);
            ++this.index;
        }

        /** Replaces the expectations with a zeroed array sized to the current entries. */
        public void zeroExpectation() {
            this.expectation = new double[this.labels.size()];
        }

        /** Copies each entry's expectation into {@code expectations} at its parameter index. */
        public void getExpectations(double[] expectations) {
            for (int i = 0; i < this.paramIndices.size(); i++) {
                expectations[this.paramIndices.get(i)] = this.expectation[i];
            }
        }

        /** Adds {@code expectations} at each entry's parameter index to the stored expectation. */
        public void addExpectations(double[] expectations) {
            for (int i = 0; i < this.paramIndices.size(); i++) {
                this.expectation[i] += expectations[this.paramIndices.get(i)];
            }
        }

        /** Adds {@code value} to the expectation for label {@code li}, if that label is constrained. */
        public void incrementExpectation(int li, double value) {
            Integer slot = this.labelMap.get(li);
            if (slot != null) {
                this.expectation[slot] += value;
            }
        }

        /**
         * Returns the parameter for label {@code li} (divided by {@code count}
         * when normalized), or 0 if the label is not constrained.
         *
         * <p>NOTE(review): in normalized mode this still divides by
         * {@code count} even when it is 0; that only happens if the feature
         * was cached without having been counted by
         * {@code preProcess(InstanceList)}.
         */
        public double getScore(int li, double[] parameters) {
            Integer slot = this.labelMap.get(li);
            if (slot == null) {
                return 0.0;
            }
            double param = parameters[this.paramIndices.get(slot)];
            return OneLabelL2IndPRConstraints.this.normalized ? param / this.count : param;
        }

        /** Returns the sum over entries of {@code target * p - p * p / (2 * weight)}. */
        public double getProjectionValueContrib(double[] parameters) {
            double value = 0.0;
            for (int i = 0; i < this.paramIndices.size(); i++) {
                double param = parameters[this.paramIndices.get(i)];
                value += this.targets.get(i) * param - param * param / (2.0 * this.weights.get(i));
            }
            return value;
        }

        /**
         * Returns the sum over entries of
         * {@code weight * (target - e)^2 / 2}, where {@code e} is the
         * expectation, divided by {@code count} when normalized.
         */
        public double getCompleteValueContrib() {
            double value = 0.0;
            for (int i = 0; i < this.paramIndices.size(); i++) {
                double expected = OneLabelL2IndPRConstraints.this.normalized
                        ? this.normalizedExpectation(i)
                        : this.expectation[i];
                value += this.weights.get(i) * Math.pow(this.targets.get(i) - expected, 2.0) / 2.0;
            }
            return value;
        }

        /**
         * Adds {@code target - e - p / weight} to {@code gradient} at each
         * entry's parameter index, where {@code e} is the expectation, divided
         * by {@code count} when normalized.
         */
        public void getGradient(double[] parameters, double[] gradient) {
            for (int i = 0; i < this.paramIndices.size(); i++) {
                int pi = this.paramIndices.get(i);
                double expected = OneLabelL2IndPRConstraints.this.normalized
                        ? this.normalizedExpectation(i)
                        : this.expectation[i];
                gradient[pi] += this.targets.get(i) - expected - parameters[pi] / this.weights.get(i);
            }
        }

        /** Returns the number of entries added to this constraint. */
        public int getNumConstrainedLabels() {
            return this.index;
        }

        /**
         * Returns {@code expectation[i] / count}, treating the exact 0/0 case
         * as 0. A constrained feature that was never counted and never
         * incremented would otherwise yield NaN and contaminate the whole
         * objective value and gradient. Any other combination is divided as-is.
         */
        private double normalizedExpectation(int i) {
            double raw = this.expectation[i];
            if (this.count == 0.0 && raw == 0.0) {
                return 0.0;
            }
            return raw / this.count;
        }
    }
}

