/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.fst.semi_supervised.pr;

import cc.mallet.fst.CRF;
import cc.mallet.fst.Transducer;
import cc.mallet.fst.semi_supervised.pr.constraints.PRConstraint;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Sequence;
import java.util.ArrayList;
import java.util.Iterator;

/**
 * A {@link Transducer} whose transition weights are the sum of the scores of a list of
 * {@link PRConstraint}s. Each constraint has its own parameter array, sized by
 * {@link PRConstraint#numDimensions()}.
 *
 * <p>States are not owned by this model. {@link #numStates()}, {@link #getState(int)} and
 * {@link #initialStateIterator()} all delegate to the wrapped base {@link CRF}.
 *
 * <p>The "flat" parameter vector used by {@link #getParameters(double[])},
 * {@link #setParameters(double[])}, {@link #getParameter(int)},
 * {@link #setParameter(int, double)} and {@link #getValueGradient(double[])} is the
 * concatenation of the per-constraint arrays, in list order.
 *
 * <p>NOTE(review): judging by the package name, this is presumably the auxiliary model
 * used in posterior-regularization (PR) training. Confirm against the trainer that uses it.
 */
public class PRAuxiliaryModel
extends Transducer {
    private static final long serialVersionUID = 1L;
    /** Length of the flat parameter vector (sum of all constraints' dimensions). */
    private final int numParameters;
    /** One parameter array per constraint, in the same order as {@link #constraints}. */
    private final double[][] parameters;
    private final ArrayList<PRConstraint> constraints;
    private final CRF baseModel;

    /**
     * Creates a model with all parameters initialised to zero.
     *
     * @param baseModel   CRF supplying the states of this transducer
     * @param constraints constraints whose scores make up the transition weights; the list is
     *                    retained, not copied
     */
    public PRAuxiliaryModel(CRF baseModel, ArrayList<PRConstraint> constraints) {
        this.baseModel = baseModel;
        this.constraints = constraints;
        this.parameters = new double[constraints.size()][];
        int index = 0;
        for (PRConstraint constraint : constraints) {
            this.parameters[index] = new double[constraint.numDimensions()];
            ++index;
        }
        this.numParameters = totalDimensions(constraints);
    }

    /**
     * Creates a model around existing parameter arrays. The arrays are used directly, not
     * copied.
     */
    private PRAuxiliaryModel(CRF baseModel, ArrayList<PRConstraint> constraints, double[][] parameters) {
        this.baseModel = baseModel;
        this.constraints = constraints;
        this.parameters = parameters;
        this.numParameters = totalDimensions(constraints);
    }

    /** Sums {@link PRConstraint#numDimensions()} over {@code constraints}. */
    private static int totalDimensions(ArrayList<PRConstraint> constraints) {
        int total = 0;
        for (PRConstraint constraint : constraints) {
            total += constraint.numDimensions();
        }
        return total;
    }

    /**
     * Returns a model with copies of every constraint (via {@link PRConstraint#copy()}) and
     * the same base CRF.
     *
     * <p>The parameter arrays are <em>shared</em> with this model, not copied. Any change made
     * through {@link #setParameters(double[])} or {@link #setParameter(int, double)} on either
     * instance is visible to both.
     */
    public PRAuxiliaryModel copy() {
        ArrayList<PRConstraint> copy = new ArrayList<PRConstraint>();
        for (PRConstraint constraint : this.constraints) {
            copy.add(constraint.copy());
        }
        return new PRAuxiliaryModel(this.baseModel, copy, this.parameters);
    }

    /**
     * Passes the input element at {@code position} to every constraint's
     * {@code preProcess}. The element is cast to {@link FeatureVector}. {@code index} is
     * not used.
     */
    public void preProcess(int index, int position, Sequence input) {
        for (PRConstraint constraint : this.constraints) {
            constraint.preProcess((FeatureVector)input.get(position));
        }
    }

    /**
     * Returns the sum over constraints of
     * {@link PRConstraint#getAuxiliaryValueContribution(double[])}, each evaluated at that
     * constraint's own parameters.
     */
    public double getValue() {
        double value = 0.0;
        int index = 0;
        for (PRConstraint constraint : this.constraints) {
            value += constraint.getAuxiliaryValueContribution(this.parameters[index]);
            ++index;
        }
        return value;
    }

    /**
     * Returns the sum over constraints of
     * {@link PRConstraint#getCompleteValueContribution(double[])}, each evaluated at that
     * constraint's own parameters.
     */
    public double getCompleteValueContribution() {
        double value = 0.0;
        int index = 0;
        for (PRConstraint constraint : this.constraints) {
            value += constraint.getCompleteValueContribution(this.parameters[index]);
            ++index;
        }
        return value;
    }

    /**
     * Writes each constraint's gradient into {@code gradient} at that constraint's offset in
     * the flat parameter vector.
     *
     * @param gradient destination array; must have at least {@link #numParameters()} entries
     */
    public void getValueGradient(double[] gradient) {
        int index = 0;
        int start = 0;
        for (PRConstraint constraint : this.constraints) {
            // A fresh zeroed buffer per constraint. Whether getGradient assigns or accumulates
            // is not visible here, so do not reuse or alias the destination array.
            double[] constraintGradient = new double[constraint.numDimensions()];
            constraint.getGradient(this.parameters[index], constraintGradient);
            System.arraycopy(constraintGradient, 0, gradient, start, constraintGradient.length);
            start += constraint.numDimensions();
            ++index;
        }
    }

    /**
     * Returns the transition weight: the sum of {@link PRConstraint#getScore} over all
     * constraints, for the input element at {@code position} and the source/destination
     * state indices of {@code iter}. {@code index} is not used.
     */
    public double getWeight(int index, int position, Sequence input, Transducer.TransitionIterator iter) {
        double weight = 0.0;
        int si1 = iter.getSourceState().getIndex();
        int si2 = iter.getDestinationState().getIndex();
        int constrIndex = 0;
        for (PRConstraint constraint : this.constraints) {
            weight += constraint.getScore((FeatureVector)input.get(position), position, si1, si2, this.parameters[constrIndex]);
            ++constrIndex;
        }
        return weight;
    }

    /**
     * Adds {@code prob} to every constraint's expectations for this transition. The call
     * uses the input element at {@code position} and the source/destination state indices
     * of {@code iter}. {@code index} is not used.
     */
    public void incrementTransition(int index, int position, Sequence input, Transducer.TransitionIterator iter, double prob) {
        int si1 = iter.getSourceState().getIndex();
        int si2 = iter.getDestinationState().getIndex();
        for (PRConstraint constraint : this.constraints) {
            constraint.incrementExpectations((FeatureVector)input.get(position), position, si1, si2, prob);
        }
    }

    /** Calls {@link PRConstraint#zeroExpectations()} on every constraint. */
    public void zeroExpectations() {
        for (PRConstraint constraint : this.constraints) {
            constraint.zeroExpectations();
        }
    }

    /** Returns the length of the flat parameter vector. */
    public int numParameters() {
        return this.numParameters;
    }

    /**
     * Copies all parameters, concatenated in constraint order, into {@code params}.
     *
     * @param params destination; its length is asserted to equal {@link #numParameters()}
     */
    public void getParameters(double[] params) {
        assert (params.length == this.numParameters);
        int start = 0;
        for (double[] block : this.parameters) {
            System.arraycopy(block, 0, params, start, block.length);
            start += block.length;
        }
    }

    /**
     * Returns the parameter at flat index {@code index}.
     *
     * @param index position in the flat parameter vector, starting at 0
     * @throws RuntimeException if {@code index} is not less than {@link #numParameters()}
     */
    public double getParameter(int index) {
        // 0 is a valid flat index (first parameter of the first constraint).
        assert (index >= 0);
        int offset = index;
        int constrIndex = 0;
        for (PRConstraint constraint : this.constraints) {
            if (offset < constraint.numDimensions()) {
                return this.parameters[constrIndex][offset];
            }
            ++constrIndex;
            offset -= constraint.numDimensions();
        }
        // Report the caller's index, not the residual offset left after the scan.
        throw new RuntimeException("index not found: " + index);
    }

    /**
     * Overwrites all parameters from {@code params}, which is read in constraint order. The
     * existing arrays are updated in place.
     *
     * @param params source; its length is asserted to equal {@link #numParameters()}
     */
    public void setParameters(double[] params) {
        assert (params.length == this.numParameters);
        int start = 0;
        for (double[] block : this.parameters) {
            System.arraycopy(params, start, block, 0, block.length);
            start += block.length;
        }
    }

    /**
     * Sets the parameter at flat index {@code index} to {@code value}.
     *
     * @param index position in the flat parameter vector, starting at 0
     * @param value new parameter value
     * @throws RuntimeException if {@code index} is not less than {@link #numParameters()}
     */
    public void setParameter(int index, double value) {
        // 0 is a valid flat index (first parameter of the first constraint).
        assert (index >= 0);
        int offset = index;
        int constrIndex = 0;
        for (PRConstraint constraint : this.constraints) {
            if (offset < constraint.numDimensions()) {
                this.parameters[constrIndex][offset] = value;
                return;
            }
            ++constrIndex;
            offset -= constraint.numDimensions();
        }
        // Report the caller's index, not the residual offset left after the scan.
        throw new RuntimeException("index not found: " + index);
    }

    /** Returns the number of constraints. */
    public int numConstraints() {
        return this.constraints.size();
    }

    /** Returns the constraint at position {@code index} in the constraint list. */
    public PRConstraint getConstraint(int index) {
        return this.constraints.get(index);
    }

    /** Returns the CRF whose states this model uses. */
    public CRF getBaseModel() {
        return this.baseModel;
    }

    /** Delegates to the base CRF. */
    @Override
    public int numStates() {
        return this.baseModel.numStates();
    }

    /** Delegates to the base CRF. */
    @Override
    public Transducer.State getState(int index) {
        return this.baseModel.getState(index);
    }

    /** Delegates to the base CRF. */
    @Override
    public Iterator initialStateIterator() {
        return this.baseModel.initialStateIterator();
    }
}

