/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.classify;

import cc.mallet.classify.Classifier;
import cc.mallet.classify.ClassifierTrainer;
import cc.mallet.classify.FeatureConstraintUtil;
import cc.mallet.classify.MaxEnt;
import cc.mallet.classify.MaxEntOptimizableByLabelDistribution;
import cc.mallet.classify.PRAuxClassifier;
import cc.mallet.classify.PRAuxClassifierOptimizable;
import cc.mallet.classify.constraints.pr.MaxEntL2FLPRConstraints;
import cc.mallet.classify.constraints.pr.MaxEntPRConstraint;
import cc.mallet.optimize.LimitedMemoryBFGS;
import cc.mallet.optimize.Optimizer;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.LabelVector;
import cc.mallet.types.MatrixOps;
import cc.mallet.types.NullLabel;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.Maths;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.HashMap;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * Trains a {@link MaxEnt} classifier {@code p} by alternating two optimizations:
 * an auxiliary {@link PRAuxClassifier} {@code q} built from a list of
 * {@link MaxEntPRConstraint}s is fitted first, then {@code p} is fitted to the
 * label distributions that {@code q} (combined with the previous {@code p})
 * assigns to the training instances.
 *
 * <p>Constraints are either passed to the constructor or read from a file set
 * with {@link #setConstraintsFile(String)}. One of the two must be supplied
 * before calling any {@code train} method.
 *
 * <p>NOTE(review): "PR" presumably stands for posterior regularization (the
 * constraint classes live in {@code constraints.pr}) -- confirm against the
 * MALLET documentation.
 */
public class MaxEntPRTrainer
extends ClassifierTrainer<MaxEnt>
implements ClassifierTrainer.ByOptimization<MaxEnt> {
    private static final Logger logger = MalletLogger.getLogger(MaxEntPRTrainer.class.getName());
    /** Passed through to the constraints object built from a constraints file. */
    private boolean normalize = true;
    /** Passed through to the constraints object built from a constraints file. */
    private boolean useValues = false;
    private int minIterations = 10;
    private int maxIterations = 500;
    /** Weight handed to each file-based constraint; has no initializer, so it is 0.0 until set. */
    private double qGPV;
    private String constraintsFile;
    private boolean converged = false;
    private int numIterations = 0;
    /** Relative tolerance on the change of the objective value between iterations. */
    private double tolerance = 0.001;
    /** Gaussian prior variance used when fitting {@code p}; has no initializer, so it is 0.0 until set. */
    private double pGPV;
    private ArrayList<MaxEntPRConstraint> constraints;
    private MaxEnt p;
    private PRAuxClassifier q;

    /**
     * Creates a trainer without constraints; a constraints file must be set
     * with {@link #setConstraintsFile(String)} before training.
     */
    public MaxEntPRTrainer() {
    }

    /**
     * Creates a trainer that uses the given constraints.
     *
     * @param constraints constraints used to build the auxiliary classifier
     */
    public MaxEntPRTrainer(ArrayList<MaxEntPRConstraint> constraints) {
        this.constraints = constraints;
    }

    /**
     * Sets the Gaussian prior variance used when optimizing {@code p}.
     *
     * @param pGPV the variance value
     */
    public void setPGaussianPriorVariance(double pGPV) {
        this.pGPV = pGPV;
    }

    /**
     * Sets the value passed as the third argument of {@code addConstraint}
     * for every constraint loaded from the constraints file.
     *
     * @param qGPV the value to pass along
     */
    public void setQGaussianPriorVariance(double qGPV) {
        this.qGPV = qGPV;
    }

    /**
     * Sets the file from which constraints are read when none were supplied
     * to the constructor.
     *
     * @param filename path handed to {@code FeatureConstraintUtil.readConstraintsFromFile}
     */
    public void setConstraintsFile(String filename) {
        this.constraintsFile = filename;
    }

    /**
     * Sets the {@code useValues} flag forwarded to file-based constraints.
     *
     * @param flag the new flag value
     */
    public void setUseValues(boolean flag) {
        this.useValues = flag;
    }

    /**
     * Sets the default minimum number of iterations before the convergence
     * test may stop training.
     *
     * @param minIterations the minimum iteration count
     */
    public void setMinIterations(int minIterations) {
        this.minIterations = minIterations;
    }

    /**
     * Sets the default maximum number of iterations used by
     * {@link #train(InstanceList)}.
     *
     * @param maxIterations the maximum iteration count
     */
    public void setMaxIterations(int maxIterations) {
        this.maxIterations = maxIterations;
    }

    /**
     * Sets the {@code normalize} flag forwarded to file-based constraints.
     *
     * @param normalize the new flag value
     */
    public void setNormalize(boolean normalize) {
        this.normalize = normalize;
    }

    /**
     * Not supported by this trainer.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public Optimizer getOptimizer() {
        throw new UnsupportedOperationException("Not yet implemented!");
    }

    /**
     * Returns the iteration counter of the most recent training run. When
     * training stopped on the convergence test this is the zero-based index
     * of the last iteration; when the iteration cap was reached it equals
     * that cap.
     */
    @Override
    public int getIteration() {
        return this.numIterations;
    }

    /**
     * Returns whether the most recent training run stopped because the
     * convergence test succeeded.
     */
    @Override
    public boolean isFinishedTraining() {
        return this.converged;
    }

    /** Returns the classifier {@code p}, or {@code null} if none has been trained yet. */
    @Override
    public MaxEnt getClassifier() {
        return this.p;
    }

    /**
     * Trains using the configured maximum number of iterations.
     *
     * @param trainingSet the training data
     * @return the trained classifier
     */
    @Override
    public MaxEnt train(InstanceList trainingSet) {
        return this.train(trainingSet, this.maxIterations);
    }

    /**
     * Trains for at most {@code maxIterations} iterations; the minimum is the
     * configured minimum, capped at {@code maxIterations}.
     *
     * @param trainingSet the training data
     * @param maxIterations upper bound on the number of iterations
     * @return the trained classifier
     */
    @Override
    public MaxEnt train(InstanceList trainingSet, int maxIterations) {
        return this.train(trainingSet, Math.min(maxIterations, this.minIterations), maxIterations);
    }

    /**
     * Runs the alternating optimization of {@code q} and {@code p}.
     *
     * <p>Only instances flagged by at least one constraint's
     * {@code preProcess} are used. Flagged instances without a target are
     * unlocked and given a {@link NullLabel} target, which modifies the
     * instances held by {@code data}.
     *
     * <p>Training stops when the iteration cap is reached, or once at least
     * {@code minIterations} iterations have run and the relative change of
     * the objective value falls within the tolerance.
     *
     * @param data the training data
     * @param minIterations iterations to run before the convergence test may stop training
     * @param maxIterations upper bound on the number of iterations
     * @return the trained classifier
     * @throws IllegalStateException if no constraints were supplied and no constraints file was set
     */
    public MaxEnt train(InstanceList data, int minIterations, int maxIterations) {
        if (this.constraints == null && this.constraintsFile != null) {
            this.constraints = this.loadConstraintsFromFile(data);
        }
        if (this.constraints == null) {
            throw new IllegalStateException("No constraints available: pass them to the constructor or call setConstraintsFile before training");
        }
        // A previous run may have left this set; this run has to earn it again.
        this.converged = false;

        InstanceList unlabeled = this.selectConstrainedInstances(data);
        int numFeatures = unlabeled.getDataAlphabet().size();
        // One weight per feature plus one extra slot per label.
        int numParameters = (numFeatures + 1) * unlabeled.getTargetAlphabet().size();
        if (this.p == null) {
            this.p = new MaxEnt(unlabeled.getPipe(), new double[numParameters]);
        }
        this.q = new PRAuxClassifier(unlabeled.getPipe(), this.constraints);

        double oldValue = -Double.MAX_VALUE;
        this.numIterations = 0;
        while (this.numIterations < maxIterations) {
            double[][] base = this.optimizeQ(unlabeled, this.p, this.numIterations == 0);
            double value = this.optimizePAndComputeValue(unlabeled, this.q, base, this.pGPV);
            logger.info("iteration " + this.numIterations + " total value " + value);
            boolean pastMinimum = this.numIterations >= minIterations - 1;
            boolean withinTolerance = 2.0 * Math.abs(value - oldValue) <= this.tolerance * (Math.abs(value) + Math.abs(oldValue) + 1.0E-5);
            if (pastMinimum && withinTolerance) {
                logger.info("PR value difference below tolerance (oldValue: " + oldValue + " newValue: " + value + ")");
                this.converged = true;
                break;
            }
            oldValue = value;
            ++this.numIterations;
        }
        return this.p;
    }

    /** Reads the constraints file and wraps its entries in a single constraints object. */
    private ArrayList<MaxEntPRConstraint> loadConstraintsFromFile(InstanceList data) {
        HashMap<Integer, double[]> constraintsMap = FeatureConstraintUtil.readConstraintsFromFile(this.constraintsFile, data);
        logger.info("number of constraints: " + constraintsMap.size());
        MaxEntL2FLPRConstraints prConstraints = new MaxEntL2FLPRConstraints(data.getDataAlphabet().size(), data.getTargetAlphabet().size(), this.useValues, this.normalize);
        for (Map.Entry<Integer, double[]> entry : constraintsMap.entrySet()) {
            prConstraints.addConstraint(entry.getKey(), entry.getValue(), this.qGPV);
        }
        ArrayList<MaxEntPRConstraint> loaded = new ArrayList<MaxEntPRConstraint>();
        loaded.add(prConstraints);
        return loaded;
    }

    /** Collects the instances flagged by any constraint, giving target-less ones a NullLabel target. */
    private InstanceList selectConstrainedInstances(InstanceList data) {
        BitSet instancesWithConstraints = new BitSet(data.size());
        for (MaxEntPRConstraint constraint : this.constraints) {
            instancesWithConstraints.or(constraint.preProcess(data));
        }
        InstanceList selected = data.cloneEmpty();
        for (int ii = 0; ii < data.size(); ++ii) {
            if (!instancesWithConstraints.get(ii)) {
                continue;
            }
            Instance instance = (Instance)data.get(ii);
            if (instance.getTarget() == null) {
                // The instance is unlocked first because its target is about to be replaced.
                instance.unLock();
                instance.setTarget(new NullLabel((LabelAlphabet)data.getTargetAlphabet()));
            }
            selected.add(instance);
        }
        return selected;
    }

    /**
     * Builds a soft-labeled copy of {@code data} from the scores of {@code q}
     * (combined with {@code base} when present), fits {@code p} to it and
     * returns the resulting objective value.
     *
     * @param data instances to label
     * @param q auxiliary classifier providing per-label scores
     * @param base per-instance label probabilities from the previous {@code p}, or {@code null}
     * @param pGPV Gaussian prior variance for the optimization of {@code p}
     * @return constraint value contributions plus the summed entropies plus the optimizer's value
     */
    private double optimizePAndComputeValue(InstanceList data, PRAuxClassifier q, double[][] base, double pGPV) {
        InstanceList dataLabeled = data.cloneEmpty();
        double entropy = 0.0;
        int numLabels = data.getTargetAlphabet().size();
        for (int ii = 0; ii < data.size(); ++ii) {
            Instance original = (Instance)data.get(ii);
            double[] scores = new double[numLabels];
            q.getClassificationScores(original, scores);
            if (base != null) {
                for (int li = 0; li < numLabels; ++li) {
                    if (base[ii][li] == 0.0) {
                        // A zero base probability rules the label out entirely.
                        scores[li] = Double.NEGATIVE_INFINITY;
                    } else {
                        scores[li] = scores[li] + Math.log(base[ii][li]);
                    }
                }
            }
            MatrixOps.expNormalize(scores);
            entropy += Maths.getEntropy(scores);
            LabelVector lv = new LabelVector((LabelAlphabet)data.getTargetAlphabet(), scores);
            dataLabeled.add(new Instance(original.getData(), lv, null, null));
        }
        MaxEntOptimizableByLabelDistribution opt = new MaxEntOptimizableByLabelDistribution(dataLabeled, this.p);
        opt.setGaussianPriorVariance(pGPV);
        MaxEntPRTrainer.optimizeWithRestart(new LimitedMemoryBFGS(opt), "p");
        double value = 0.0;
        for (MaxEntPRConstraint constraint : q.getConstraintFeatures()) {
            value += constraint.getCompleteValueContribution();
        }
        // Grouping kept as in the compound assignment this replaces, so the floating-point sum is unchanged.
        return value + (entropy + opt.getValue());
    }

    /**
     * Fits the auxiliary classifier {@code q} against the current predictions
     * of {@code p}.
     *
     * @param data instances to optimize over
     * @param p classifier whose label distributions form the base
     * @param firstIter when true no base distributions are computed
     * @return per-instance label values produced by {@code p}, or {@code null} on the first iteration
     */
    private double[][] optimizeQ(InstanceList data, Classifier p, boolean firstIter) {
        double[][] base = null;
        if (!firstIter) {
            int numLabels = data.getTargetAlphabet().size();
            base = new double[data.size()][numLabels];
            for (int ii = 0; ii < data.size(); ++ii) {
                p.classify((Instance)data.get(ii)).getLabelVector().addTo(base[ii]);
            }
        }
        PRAuxClassifierOptimizable optimizable = new PRAuxClassifierOptimizable(data, base, this.q);
        MaxEntPRTrainer.optimizeWithRestart(new LimitedMemoryBFGS(optimizable), "q");
        return base;
    }

    /** Runs the optimizer, resets it, and runs it a second time. */
    private static void optimizeWithRestart(LimitedMemoryBFGS bfgs, String target) {
        MaxEntPRTrainer.optimizeLoggingFailure(bfgs, target);
        bfgs.reset();
        MaxEntPRTrainer.optimizeLoggingFailure(bfgs, target);
    }

    /** Runs one optimization pass; a failure is logged and training carries on. */
    private static void optimizeLoggingFailure(LimitedMemoryBFGS bfgs, String target) {
        try {
            bfgs.optimize();
        }
        catch (Exception e) {
            // Deliberately non-fatal: the caller proceeds with whatever state the optimizer reached.
            logger.log(Level.WARNING, "Optimization of " + target + " threw an exception; continuing", e);
        }
    }
}

