/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.classify;

import cc.mallet.classify.AdaBoost;
import cc.mallet.classify.Boostable;
import cc.mallet.classify.Classifier;
import cc.mallet.classify.ClassifierTrainer;
import cc.mallet.types.FeatureSelection;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.Maths;
import java.util.Random;
import java.util.logging.Logger;

/**
 * Trains an {@link AdaBoost} ensemble by repeatedly fitting a {@link Boostable} weak learner
 * on instance-weighted resamples of the training data.
 *
 * <p>Each round resamples the training set according to the current instance weights, trains
 * the weak learner on that sample, and measures its weighted error on the full weighted
 * training set. It then multiplies the weights of correctly classified instances by
 * {@code err / (1 - err)} and renormalizes. The weak learner's vote weight is
 * {@code log((1 - err) / err)}. Training stops early when a round's weighted error is
 * (almost) zero or exceeds 0.5.
 */
public class AdaBoostTrainer extends ClassifierTrainer<AdaBoost> {
    private static final Logger logger = MalletLogger.getLogger(AdaBoostTrainer.class.getName());
    /** Maximum number of resamples tried in a round while the weak learner's error stays ~0. */
    private static final int MAX_NUM_RESAMPLING_ITERATIONS = 10;
    ClassifierTrainer weakLearner;
    int numRounds;
    AdaBoost classifier;

    /**
     * Returns the classifier produced by the most recent call to {@link #train}, or
     * {@code null} if {@code train} has not been called yet.
     */
    @Override
    public AdaBoost getClassifier() {
        return this.classifier;
    }

    /**
     * @param weakLearner trainer for the weak learners; must implement {@link Boostable}
     * @param numRounds maximum number of boosting rounds; must be positive
     * @throws IllegalArgumentException if {@code weakLearner} is not boostable or
     *     {@code numRounds <= 0}
     */
    public AdaBoostTrainer(ClassifierTrainer weakLearner, int numRounds) {
        if (!(weakLearner instanceof Boostable)) {
            throw new IllegalArgumentException("weak learner not boostable");
        }
        if (numRounds <= 0) {
            throw new IllegalArgumentException("number of rounds must be positive");
        }
        this.weakLearner = weakLearner;
        this.numRounds = numRounds;
    }

    /** Creates a trainer that runs at most 100 boosting rounds. */
    public AdaBoostTrainer(ClassifierTrainer weakLearner) {
        this(weakLearner, 100);
    }

    /**
     * Boosts the weak learner on {@code trainingList}.
     *
     * <p>If a round's weighted error is ~0 (even after resampling) or above 0.5, training stops.
     * The returned ensemble then contains only the learners from earlier rounds. If the stop
     * happens in the first round, it contains that single learner with weight 1. In every
     * case the result is also what {@link #getClassifier()} returns afterwards.
     *
     * @throws UnsupportedOperationException if the training list has a feature selection
     */
    @Override
    public AdaBoost train(InstanceList trainingList) {
        FeatureSelection selectedFeatures = trainingList.getFeatureSelection();
        if (selectedFeatures != null) {
            throw new UnsupportedOperationException("FeatureSelection not yet implemented.");
        }
        Random random = new Random();
        InstanceList trainingInsts = copyWithUniformWeights(trainingList);
        boolean[] correct = new boolean[trainingInsts.size()];
        int numClasses = trainingInsts.getTargetAlphabet().size();
        if (numClasses > 2) {
            logger.info("AdaBoostTrainer.train: WARNING: more than two classes");
        } else if (numClasses < 2) {
            logger.info("AdaBoostTrainer.train: WARNING: fewer than two classes");
        }
        Classifier[] weakLearners = new Classifier[this.numRounds];
        double[] alphas = new double[this.numRounds];
        // Reassigned every round; its pipe is used to build the final ensemble.
        InstanceList roundTrainingInsts = new InstanceList(trainingInsts.getPipe());
        for (int round = 0; round < this.numRounds; round++) {
            logger.info("===========  AdaBoostTrainer round " + (round + 1) + " begin");
            double err;
            int resamplingIterations = 0;
            // A zero-error weak learner would get an infinite alpha, so retry a few times
            // with a fresh weighted resample before giving up.
            do {
                roundTrainingInsts = trainingInsts.sampleWithInstanceWeights(random);
                weakLearners[round] = this.weakLearner.train(roundTrainingInsts);
                err = weightedError(weakLearners[round], trainingInsts, correct);
            } while (Maths.almostEquals(err, 0.0) && ++resamplingIterations < MAX_NUM_RESAMPLING_ITERATIONS);
            if (Maths.almostEquals(err, 0.0) || err > 0.5) {
                logger.info("AdaBoostTrainer stopped at " + (round + 1) + " / " + this.numRounds + " rounds: numClasses=" + numClasses + " error=" + err);
                // Record the early-stopped ensemble so getClassifier() reflects this run.
                this.classifier = buildTruncatedEnsemble(roundTrainingInsts, weakLearners, alphas, round);
                return this.classifier;
            }
            alphas[round] = Math.log((1.0 - err) / err);
            reweight(trainingInsts, correct, err / (1.0 - err));
            logger.info("===========  AdaBoostTrainer round " + (round + 1) + " finished, weak classifier training error = " + err);
        }
        logWeights(alphas);
        this.classifier = new AdaBoost(roundTrainingInsts.getPipe(), weakLearners, alphas);
        return this.classifier;
    }

    /** Copies {@code source} into a new list in which every instance has weight {@code 1 / size}. */
    private static InstanceList copyWithUniformWeights(InstanceList source) {
        double w = 1.0 / (double) source.size();
        InstanceList copy = new InstanceList(source.getPipe(), source.size());
        for (int i = 0; i < source.size(); i++) {
            copy.add((Instance) source.get(i), w);
        }
        return copy;
    }

    /**
     * Classifies every instance of {@code insts} with {@code learner} and records in
     * {@code correct[i]} whether instance {@code i}'s best label is correct. Returns the
     * summed instance weight of the misclassified instances.
     */
    private static double weightedError(Classifier learner, InstanceList insts, boolean[] correct) {
        double err = 0.0;
        for (int i = 0; i < insts.size(); i++) {
            Instance inst = (Instance) insts.get(i);
            correct[i] = learner.classify(inst).bestLabelIsCorrect();
            if (!correct[i]) {
                err += insts.getInstanceWeight(i);
            }
        }
        return err;
    }

    /**
     * Multiplies the weight of each correctly classified instance by {@code factor}, then
     * renormalizes all weights so they sum to 1.
     */
    private static void reweight(InstanceList insts, boolean[] correct, double factor) {
        double sum = 0.0;
        for (int i = 0; i < insts.size(); i++) {
            double w = insts.getInstanceWeight(i);
            if (correct[i]) {
                w *= factor;
            }
            insts.setInstanceWeight(i, w);
            sum += w;
        }
        for (int i = 0; i < insts.size(); i++) {
            insts.setInstanceWeight(i, insts.getInstanceWeight(i) / sum);
        }
    }

    /**
     * Builds the ensemble used when boosting stops at {@code round}.
     *
     * <p>The ensemble contains the learners from rounds {@code 0..round-1}. If the stop is in
     * the first round, it contains that round's learner with its weight forced to 1.
     */
    private static AdaBoost buildTruncatedEnsemble(InstanceList roundTrainingInsts, Classifier[] weakLearners, double[] alphas, int round) {
        int numClassifiersToUse = round == 0 ? 1 : round;
        if (round == 0) {
            alphas[0] = 1.0;
        }
        double[] betas = new double[numClassifiersToUse];
        Classifier[] weakClassifiers = new Classifier[numClassifiersToUse];
        System.arraycopy(alphas, 0, betas, 0, numClassifiersToUse);
        System.arraycopy(weakLearners, 0, weakClassifiers, 0, numClassifiersToUse);
        logWeights(betas);
        return new AdaBoost(roundTrainingInsts.getPipe(), weakClassifiers, betas);
    }

    /** Logs the vote weight assigned to each weak learner. */
    private static void logWeights(double[] weights) {
        for (int i = 0; i < weights.length; i++) {
            logger.info("AdaBoostTrainer weight[weakLearner[" + i + "]]=" + weights[i]);
        }
    }
}

