/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.classify;

import cc.mallet.classify.Classification;
import cc.mallet.classify.Classifier;
import cc.mallet.pipe.Pipe;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.LabelVector;
import cc.mallet.types.MatrixOps;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;

/**
 * Linear multi-class classifier whose weights are supplied at construction.
 *
 * <p>The weight matrix has one row per label. Within a row, the entries before
 * the last are per-feature weights indexed by feature index. The last entry is a
 * bias term added to that label's score. {@link #classify(Instance)} returns the
 * per-label scores divided by their total.
 */
public class BalancedWinnow
extends Classifier
implements Serializable {
    /**
     * Weights indexed as {@code m_weights[labelIndex][featureIndex]}. The final
     * column of each row holds that label's bias.
     */
    double[][] m_weights;
    private static final long serialVersionUID = 1L;
    private static final int CURRENT_SERIAL_VERSION = 1;

    /**
     * Creates a classifier from a pipe and a weight matrix.
     *
     * @param dataPipe the pipe used to turn raw input into instances
     * @param weights weight matrix, one row per label with the bias in the last
     *     column; it is deep-copied, so later changes to the argument have no effect
     */
    public BalancedWinnow(Pipe dataPipe, double[][] weights) {
        super(dataPipe);
        this.m_weights = deepCopy(weights);
    }

    /**
     * Returns a deep copy of the weight matrix. Mutating the result does not
     * affect this classifier.
     *
     * @return a fresh copy of every weight row
     */
    public double[][] getWeights() {
        return deepCopy(this.m_weights);
    }

    /**
     * Scores an instance against every label.
     *
     * <p>For each label the score is the dot product of the instance's feature
     * vector with that label's weight row, plus the label's bias. All scores are
     * then divided by their sum so they total 1. This division is skipped when the
     * sum is exactly zero, because it would otherwise produce Infinity/NaN.
     * Features whose index lies beyond the stored weights (for example, features
     * added to the alphabet after training) contribute nothing.
     *
     * @param instance an instance whose data is a {@link FeatureVector}
     * @return the classification carrying the normalized per-label scores
     * @throws ClassCastException if the instance data is not a FeatureVector
     */
    @Override
    public Classification classify(Instance instance) {
        int numClasses = this.getLabelAlphabet().size();
        double[] scores = new double[numClasses];
        FeatureVector fv = (FeatureVector) instance.getData();
        assert (this.instancePipe == null || fv.getAlphabet() == this.instancePipe.getDataAlphabet());
        int numLocations = fv.numLocations();
        double sum = 0.0;
        for (int ci = 0; ci < numClasses; ci++) {
            double[] w = this.m_weights[ci];
            // The bias lives in the last column. Deriving its index from the row
            // (rather than the current alphabet size) keeps classification working
            // after the alphabet has grown past the trained weights.
            int biasIndex = w.length - 1;
            for (int loc = 0; loc < numLocations; loc++) {
                int fi = fv.indexAtLocation(loc);
                // Indices at or beyond the bias slot are unseen features: skip them
                // so they are never multiplied by the bias weight.
                if (fi < biasIndex) {
                    double contribution = fv.valueAtLocation(loc) * w[fi];
                    scores[ci] += contribution;
                    sum += contribution;
                }
            }
            scores[ci] += w[biasIndex];
            sum += w[biasIndex];
        }
        if (sum != 0.0) {
            MatrixOps.timesEquals(scores, 1.0 / sum);
        }
        return new Classification(instance, this, new LabelVector(this.getLabelAlphabet(), scores));
    }

    /**
     * Copies each row independently, so ragged and empty matrices are handled.
     */
    private static double[][] deepCopy(double[][] src) {
        double[][] copy = new double[src.length][];
        for (int i = 0; i < src.length; i++) {
            copy[i] = src[i].clone();
        }
        return copy;
    }

    /**
     * Writes, in order: the serial version, the instance pipe and the weight matrix.
     */
    private void writeObject(ObjectOutputStream out) throws IOException {
        out.writeInt(CURRENT_SERIAL_VERSION);
        out.writeObject(this.getInstancePipe());
        out.writeObject(this.m_weights);
    }

    /**
     * Reads the fields in the order written by {@link #writeObject}.
     *
     * @throws ClassNotFoundException if the stored version differs from the
     *     current one, or if a serialized class cannot be resolved
     */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        int version = in.readInt();
        if (version != CURRENT_SERIAL_VERSION) {
            throw new ClassNotFoundException("Mismatched BalancedWinnow versions: wanted "
                    + CURRENT_SERIAL_VERSION + ", got " + version);
        }
        this.instancePipe = (Pipe) in.readObject();
        this.m_weights = (double[][]) in.readObject();
    }
}

