/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.classify;

import cc.mallet.classify.ClassifierTrainer;
import cc.mallet.classify.Winnow;
import cc.mallet.pipe.Pipe;
import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureSelection;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Labeling;

import java.util.Arrays;

/**
 * Trains a {@link Winnow} classifier using multiplicative, mistake-driven weight updates.
 *
 * <p>One weight is kept per (label, feature) pair, all initialised to {@code 1.0}. Training makes
 * a single pass over the training instances. For each instance and each label, the label's score
 * is the sum of that label's weights over the feature indices present in the instance's
 * {@link FeatureVector}; feature values are not read, so features are effectively treated as
 * binary. A score strictly greater than {@code theta} predicts the label. Then:
 * <ul>
 *   <li>predicted but not the instance's best label: the active features' weights for that label
 *       are divided by {@code beta} (demotion);</li>
 *   <li>not predicted but it is the instance's best label: those weights are multiplied by
 *       {@code alpha} (promotion).</li>
 * </ul>
 * Each call to {@link #train(InstanceList)} resets the weights before training.
 */
public class WinnowTrainer
extends ClassifierTrainer<Winnow> {
    /** Default promotion factor. */
    static final double DEFAULT_ALPHA = 2.0;
    /** Default demotion factor. */
    static final double DEFAULT_BETA = 2.0;
    /** Default multiplier applied to the feature count to compute {@code theta}. */
    static final double DEFAULT_NFACTOR = 0.5;
    /** Factor by which active-feature weights are multiplied on promotion. */
    double alpha;
    /** Factor by which active-feature weights are divided on demotion. */
    double beta;
    /** Decision threshold; recomputed as {@code numFeatures * nfactor} at the start of training. */
    double theta;
    /** Multiplier applied to the data-alphabet size to obtain {@code theta}. */
    double nfactor;
    /** Weights indexed as {@code weights[labelIndex][featureIndex]}; (re)built by train. */
    double[][] weights;
    /** Classifier produced by the most recent call to train; null before the first call. */
    Winnow classifier;

    /** Creates a trainer using {@link #DEFAULT_ALPHA}, {@link #DEFAULT_BETA} and {@link #DEFAULT_NFACTOR}. */
    public WinnowTrainer() {
        this(DEFAULT_ALPHA, DEFAULT_BETA, DEFAULT_NFACTOR);
    }

    /**
     * Creates a trainer with the given promotion/demotion factors and {@link #DEFAULT_NFACTOR}.
     *
     * @param a promotion factor (alpha); must be positive
     * @param b demotion factor (beta); must be positive
     * @throws IllegalArgumentException if {@code a} or {@code b} is not positive (or is NaN)
     */
    public WinnowTrainer(double a, double b) {
        this(a, b, DEFAULT_NFACTOR);
    }

    /**
     * Creates a trainer with explicit parameters.
     *
     * @param a promotion factor (alpha); must be positive
     * @param b demotion factor (beta); must be positive
     * @param nfact multiplier applied to the number of features to obtain the threshold theta
     * @throws IllegalArgumentException if {@code a} or {@code b} is not positive (or is NaN)
     */
    public WinnowTrainer(double a, double b, double nfact) {
        // A zero beta would make demote() divide by zero, and a non-positive factor would flip or
        // zero out weights; reject such values up front instead of producing a broken model.
        if (!(a > 0.0)) {
            throw new IllegalArgumentException("alpha must be positive: " + a);
        }
        if (!(b > 0.0)) {
            throw new IllegalArgumentException("beta must be positive: " + b);
        }
        this.alpha = a;
        this.beta = b;
        this.nfactor = nfact;
    }

    /**
     * Returns the classifier built by the most recent call to {@link #train(InstanceList)}.
     *
     * @return the trained classifier, or {@code null} if {@code train} has not been called yet
     */
    @Override
    public Winnow getClassifier() {
        return this.classifier;
    }

    /**
     * Trains a new Winnow classifier on {@code trainingList}, discarding any previous weights.
     *
     * <p>Side effects: stops growth of the list's data and target alphabets, and stores the new
     * weights, threshold and classifier in this trainer.
     *
     * @param trainingList instances whose data are {@link FeatureVector}s
     * @return the trained classifier (also available via {@link #getClassifier()})
     * @throws UnsupportedOperationException if the list has a feature selection
     */
    @Override
    public Winnow train(InstanceList trainingList) {
        FeatureSelection selectedFeatures = trainingList.getFeatureSelection();
        if (selectedFeatures != null) {
            throw new UnsupportedOperationException("FeatureSelection not yet implemented.");
        }
        // Freeze both alphabets so the weight-matrix dimensions computed below stay valid.
        trainingList.getDataAlphabet().stopGrowth();
        trainingList.getTargetAlphabet().stopGrowth();
        Pipe dataPipe = trainingList.getPipe();
        Alphabet dict = trainingList.getDataAlphabet();
        int numLabels = trainingList.getTargetAlphabet().size();
        int numFeats = dict.size();
        this.theta = numFeats * this.nfactor;
        this.weights = initialWeights(numLabels, numFeats);
        for (int i = 0; i < trainingList.size(); i++) {
            Instance inst = (Instance) trainingList.get(i);
            Labeling labeling = inst.getLabeling();
            if (labeling == null) {
                // No labeling (presumably an unlabeled instance): there is no correct label to
                // learn from, so skip it rather than failing with a NullPointerException.
                continue;
            }
            FeatureVector fv = (FeatureVector) inst.getData();
            updateOnInstance(fv, labeling.getBestIndex());
        }
        this.classifier = new Winnow(dataPipe, this.weights, this.theta, numLabels, numFeats);
        return this.classifier;
    }

    /** Returns a {@code numLabels x numFeats} weight matrix with every entry set to 1.0. */
    private static double[][] initialWeights(int numLabels, int numFeats) {
        double[][] w = new double[numLabels][numFeats];
        for (double[] row : w) {
            Arrays.fill(row, 1.0);
        }
        return w;
    }

    /**
     * Scores every label on {@code fv}, then promotes the correct label if it scored at or below
     * theta and demotes every other label that scored above theta. All scores are computed
     * before any weight is changed.
     */
    private void updateOnInstance(FeatureVector fv, int correctIndex) {
        double[] scores = scoreLabels(fv);
        for (int label = 0; label < scores.length; label++) {
            boolean predicted = scores[label] > this.theta;
            if (predicted && label != correctIndex) {
                demote(label, fv);
            } else if (!predicted && label == correctIndex) {
                promote(label, fv);
            }
        }
    }

    /** Returns, per label, the sum of that label's weights over the feature indices in {@code fv}. */
    private double[] scoreLabels(FeatureVector fv) {
        double[] scores = new double[this.weights.length];
        int numLocations = fv.numLocations();
        for (int loc = 0; loc < numLocations; loc++) {
            int fi = fv.indexAtLocation(loc);
            for (int label = 0; label < scores.length; label++) {
                scores[label] += this.weights[label][fi];
            }
        }
        return scores;
    }

    /** Multiplies by {@code alpha} the weights of label {@code lpos} for every feature index in {@code fv}. */
    private void promote(int lpos, FeatureVector fv) {
        double[] row = this.weights[lpos];
        int numLocations = fv.numLocations();
        for (int loc = 0; loc < numLocations; loc++) {
            row[fv.indexAtLocation(loc)] *= this.alpha;
        }
    }

    /** Divides by {@code beta} the weights of label {@code lpos} for every feature index in {@code fv}. */
    private void demote(int lpos, FeatureVector fv) {
        double[] row = this.weights[lpos];
        int numLocations = fv.numLocations();
        for (int loc = 0; loc < numLocations; loc++) {
            row[fv.indexAtLocation(loc)] /= this.beta;
        }
    }
}

