/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.meta;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.StringReader;
import java.io.StringWriter;
import java.util.Collections;
import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;
import weka.classifiers.Classifier;
import weka.classifiers.CostMatrix;
import weka.classifiers.RandomizableSingleClassifierEnhancer;
import weka.classifiers.rules.ZeroR;
import weka.core.BatchPredictor;
import weka.core.Capabilities;
import weka.core.Drawable;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;

/**
 * Meta classifier that makes its base classifier cost sensitive.
 *
 * <p>Two strategies are supported (see {@link #globalInfo()}): reweighting the
 * training data according to the cost matrix, or predicting the class with
 * minimum expected misclassification cost. The cost matrix is either supplied
 * explicitly ({@link #MATRIX_SUPPLIED}) or loaded from a file named after the
 * training relation when the classifier is built ({@link #MATRIX_ON_DEMAND}).
 */
public class CostSensitiveClassifier
extends RandomizableSingleClassifierEnhancer
implements OptionHandler,
Drawable,
BatchPredictor,
WeightedInstancesHandler {
    static final long serialVersionUID = -110658209263002404L;

    /** Matrix source: load "relation name + extension" from the on-demand directory at build time. */
    public static final int MATRIX_ON_DEMAND = 1;

    /** Matrix source: use the matrix held in {@link #m_CostMatrix}. */
    public static final int MATRIX_SUPPLIED = 2;

    /** Tags describing the two matrix sources; also used for identity checks in {@link #setCostMatrixSource}. */
    public static final Tag[] TAGS_MATRIX_SOURCE = new Tag[]{new Tag(MATRIX_ON_DEMAND, "Load cost matrix on demand"), new Tag(MATRIX_SUPPLIED, "Use explicit cost matrix")};

    /** Where the cost matrix comes from; one of {@link #MATRIX_ON_DEMAND} or {@link #MATRIX_SUPPLIED}. */
    protected int m_MatrixSource = MATRIX_ON_DEMAND;

    /** Directory searched for cost files when loading on demand; defaults to the working directory. */
    protected File m_OnDemandDirectory = new File(System.getProperty("user.dir"));

    /** Name of the cost file given with -C, or null if none was given. */
    protected String m_CostFile;

    /** The cost matrix in use; may be null while an old-format cost file is still waiting to be read. */
    protected CostMatrix m_CostMatrix = new CostMatrix(1);

    /** True to predict the minimum-expected-cost class, false to reweight the training data. */
    protected boolean m_MinimizeExpectedCost;

    /**
     * Returns the class name of the default base classifier.
     *
     * @return the fully qualified name of ZeroR
     */
    @Override
    protected String defaultClassifierString() {
        return "weka.classifiers.rules.ZeroR";
    }

    /** Creates the classifier with ZeroR as its base classifier. */
    public CostSensitiveClassifier() {
        this.m_Classifier = new ZeroR();
    }

    /**
     * Lists the command-line options of this classifier followed by those of the superclass.
     *
     * @return an enumeration of all available options
     */
    @Override
    public Enumeration<Option> listOptions() {
        Vector<Option> newVector = new Vector<Option>(4);
        newVector.addElement(new Option("\tMinimize expected misclassification cost. Default is to\n\treweight training instances according to costs per class", "M", 0, "-M"));
        newVector.addElement(new Option("\tFile name of a cost matrix to use. If this is not supplied,\n\ta cost matrix will be loaded on demand. The name of the\n\ton-demand file is the relation name of the training data\n\tplus \".cost\", and the path to the on-demand file is\n\tspecified with the -N option.", "C", 1, "-C <cost file name>"));
        newVector.addElement(new Option("\tName of a directory to search for cost files when loading\n\tcosts on demand (default current directory).", "N", 1, "-N <directory>"));
        newVector.addElement(new Option("\tThe cost matrix in Matlab single line format.", "cost-matrix", 1, "-cost-matrix <matrix>"));
        newVector.addAll(Collections.list(super.listOptions()));
        return newVector.elements();
    }

    /**
     * Parses the options -M, -C, -N and -cost-matrix, then hands the remaining
     * options to the superclass.
     *
     * @param options the option array to parse
     * @throws Exception if an option is invalid or unrecognised options remain
     */
    @Override
    public void setOptions(String[] options) throws Exception {
        this.setMinimizeExpectedCost(Utils.getFlag('M', options));
        String costFile = Utils.getOption('C', options);
        if (costFile.length() != 0) {
            // try-with-resources so the file handle is released whether or not parsing succeeds.
            try (BufferedReader reader = new BufferedReader(new FileReader(costFile))) {
                this.setCostMatrix(new CostMatrix(reader));
            }
            catch (Exception ignored) {
                // Deliberately not propagated: a null matrix makes buildClassifier()
                // retry this file with CostMatrix.readOldFormat().
                this.setCostMatrix(null);
            }
            this.setCostMatrixSource(new SelectedTag(MATRIX_SUPPLIED, TAGS_MATRIX_SOURCE));
            this.m_CostFile = costFile;
        } else {
            this.setCostMatrixSource(new SelectedTag(MATRIX_ON_DEMAND, TAGS_MATRIX_SOURCE));
        }
        String demandDir = Utils.getOption('N', options);
        if (demandDir.length() != 0) {
            this.setOnDemandDirectory(new File(demandDir));
        }
        String matlabMatrix = Utils.getOption("cost-matrix", options);
        if (matlabMatrix.length() != 0) {
            // Round-trip the parsed matrix through its written form, as the original code does.
            StringWriter writer = new StringWriter();
            CostMatrix.parseMatlab(matlabMatrix).write(writer);
            this.setCostMatrix(new CostMatrix(new StringReader(writer.toString())));
            this.setCostMatrixSource(new SelectedTag(MATRIX_SUPPLIED, TAGS_MATRIX_SOURCE));
        }
        super.setOptions(options);
        Utils.checkForRemainingOptions(options);
    }

    /**
     * Returns the current settings as an option array.
     *
     * @return the options of this classifier followed by those of the superclass
     */
    @Override
    public String[] getOptions() {
        Vector<String> options = new Vector<String>();
        if (this.m_MatrixSource == MATRIX_SUPPLIED) {
            if (this.m_CostFile != null) {
                options.add("-C");
                options.add(this.m_CostFile);
            } else if (this.getCostMatrix() != null) {
                // Guarded: with no cost file and a null matrix there is nothing to report,
                // and calling toMatlab() would throw a NullPointerException.
                options.add("-cost-matrix");
                options.add(this.getCostMatrix().toMatlab());
            }
        } else {
            options.add("-N");
            options.add(String.valueOf(this.getOnDemandDirectory()));
        }
        if (this.getMinimizeExpectedCost()) {
            options.add("-M");
        }
        Collections.addAll(options, super.getOptions());
        return options.toArray(new String[0]);
    }

    /**
     * Returns a description of this classifier.
     *
     * @return the description text
     */
    public String globalInfo() {
        return "A metaclassifier that makes its base classifier cost sensitive. Two methods can be used to introduce cost-sensitivity: reweighting training instances according to the total cost assigned to each class; or predicting the class with minimum expected misclassification cost (rather than the most likely class). Performance can often be improved by using a bagged classifier to improve the probability estimates of the base classifier. If the base classifier cannot handle instance weights, and the instance weights are not uniform, the data will be resampled with replacement based on the weights before being passed to the base classifier.";
    }

    /**
     * Returns the tip text for the costMatrixSource property.
     *
     * @return the tip text
     */
    public String costMatrixSourceTipText() {
        return "Sets where to get the cost matrix. The two options are to use the supplied explicit cost matrix (the setting of the costMatrix property), or to load a cost matrix from a file when required (this file will be loaded from the directory set by the onDemandDirectory property and will be named relation_name" + CostMatrix.FILE_EXTENSION + ").";
    }

    /**
     * Returns the current matrix source as a tag.
     *
     * @return a tag over {@link #TAGS_MATRIX_SOURCE} selecting the current source
     */
    public SelectedTag getCostMatrixSource() {
        return new SelectedTag(this.m_MatrixSource, TAGS_MATRIX_SOURCE);
    }

    /**
     * Sets the matrix source. Tags that do not belong to {@link #TAGS_MATRIX_SOURCE}
     * (checked by array identity) are silently ignored.
     *
     * @param newMethod the tag selecting the new source
     */
    public void setCostMatrixSource(SelectedTag newMethod) {
        if (newMethod.getTags() == TAGS_MATRIX_SOURCE) {
            this.m_MatrixSource = newMethod.getSelectedTag().getID();
        }
    }

    /**
     * Returns the tip text for the onDemandDirectory property.
     *
     * @return the tip text
     */
    public String onDemandDirectoryTipText() {
        return "Sets the directory where cost files are loaded from. This option is used when the costMatrixSource is set to \"On Demand\".";
    }

    /**
     * Returns the directory searched for on-demand cost files.
     *
     * @return the on-demand directory
     */
    public File getOnDemandDirectory() {
        return this.m_OnDemandDirectory;
    }

    /**
     * Sets the directory searched for on-demand cost files and switches the matrix
     * source to {@link #MATRIX_ON_DEMAND}. If {@code newDir} is not an existing
     * directory its parent is used instead.
     *
     * @param newDir the directory, or a path whose parent directory should be used
     */
    public void setOnDemandDirectory(File newDir) {
        if (newDir.isDirectory()) {
            this.m_OnDemandDirectory = newDir;
        } else {
            String parent = newDir.getParent();
            if (parent == null) {
                // A bare relative name such as "costs" has no parent component;
                // new File((String) null) would throw, so resolve against the absolute path.
                parent = newDir.getAbsoluteFile().getParent();
            }
            this.m_OnDemandDirectory = new File(parent != null ? parent : System.getProperty("user.dir"));
        }
        this.m_MatrixSource = MATRIX_ON_DEMAND;
    }

    /**
     * Returns the tip text for the minimizeExpectedCost property.
     *
     * @return the tip text
     */
    public String minimizeExpectedCostTipText() {
        return "Sets whether the minimum expected cost criteria will be used. If this is false, the training data will be reweighted according to the costs assigned to each class. If true, the minimum expected cost criteria will be used.";
    }

    /**
     * Returns whether the minimum expected cost criterion is used.
     *
     * @return true if predictions minimise expected cost, false if training data is reweighted
     */
    public boolean getMinimizeExpectedCost() {
        return this.m_MinimizeExpectedCost;
    }

    /**
     * Sets whether the minimum expected cost criterion is used.
     *
     * @param newMinimizeExpectedCost true to minimise expected cost at prediction time
     */
    public void setMinimizeExpectedCost(boolean newMinimizeExpectedCost) {
        this.m_MinimizeExpectedCost = newMinimizeExpectedCost;
    }

    /**
     * Returns the base classifier's class name, followed by its options when it
     * is an {@link OptionHandler}.
     *
     * @return the classifier specification string
     */
    @Override
    protected String getClassifierSpec() {
        Classifier c = this.getClassifier();
        String className = c.getClass().getName();
        if (c instanceof OptionHandler) {
            return className + " " + Utils.joinOptions(((OptionHandler) c).getOptions());
        }
        return className;
    }

    /**
     * Returns the tip text for the costMatrix property.
     *
     * @return the tip text
     */
    public String costMatrixTipText() {
        return "Sets the cost matrix explicitly. This matrix is used if the costMatrixSource property is set to \"Supplied\".";
    }

    /**
     * Returns the cost matrix currently held.
     *
     * @return the cost matrix, possibly null
     */
    public CostMatrix getCostMatrix() {
        return this.m_CostMatrix;
    }

    /**
     * Sets the cost matrix and switches the matrix source to {@link #MATRIX_SUPPLIED}.
     *
     * @param newCostMatrix the matrix to use
     */
    public void setCostMatrix(CostMatrix newCostMatrix) {
        this.m_CostMatrix = newCostMatrix;
        this.m_MatrixSource = MATRIX_SUPPLIED;
    }

    /**
     * Returns the capabilities of this classifier: those of the superclass with
     * all class types and class dependencies disabled, and nominal classes enabled.
     *
     * @return the capabilities
     */
    @Override
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAllClasses();
        result.disableAllClassDependencies();
        result.enable(Capabilities.Capability.NOMINAL_CLASS);
        return result;
    }

    /**
     * Builds the base classifier on a copy of the data after resolving the cost
     * matrix and, unless expected cost is minimised at prediction time, applying
     * the cost matrix to the training data.
     *
     * @param data the training data; it is copied, not modified
     * @throws Exception if the data fails the capability test, no base classifier
     *         is set, the cost matrix cannot be obtained, or the base classifier fails
     */
    @Override
    public void buildClassifier(Instances data) throws Exception {
        this.getCapabilities().testWithFail(data);
        data = new Instances(data);
        data.deleteWithMissingClass();
        if (this.m_Classifier == null) {
            throw new Exception("No base classifier has been set!");
        }
        if (this.m_MatrixSource == MATRIX_ON_DEMAND) {
            String costName = data.relationName() + CostMatrix.FILE_EXTENSION;
            File costFile = new File(this.getOnDemandDirectory(), costName);
            if (!costFile.exists()) {
                throw new Exception("On-demand cost file doesn't exist: " + costFile);
            }
            try (BufferedReader reader = new BufferedReader(new FileReader(costFile))) {
                // Assign the field directly: setCostMatrix() would also switch the source
                // to MATRIX_SUPPLIED, so later builds would stop loading on demand.
                this.m_CostMatrix = new CostMatrix(reader);
            }
        } else if (this.m_CostMatrix == null) {
            // setOptions() left the matrix null because the -C file did not parse in
            // the current format; retry it here using the old format.
            if (this.m_CostFile == null) {
                throw new Exception("No cost matrix has been supplied and no cost file is set!");
            }
            CostMatrix oldFormatMatrix = new CostMatrix(data.numClasses());
            try (BufferedReader reader = new BufferedReader(new FileReader(this.m_CostFile))) {
                oldFormatMatrix.readOldFormat(reader);
            }
            // Publish only after a successful read so a failure does not leave a default matrix behind.
            this.m_CostMatrix = oldFormatMatrix;
        }
        boolean baseHandlesWeights = this.m_Classifier instanceof WeightedInstancesHandler;
        if (!this.m_MinimizeExpectedCost) {
            // A random generator is passed only when the base classifier cannot use instance weights.
            Random random = baseHandlesWeights ? null : new Random(this.m_Seed);
            data = this.m_CostMatrix.applyCostMatrix(data, random);
        } else if (!data.allInstanceWeightsIdentical() && !baseHandlesWeights) {
            Random r = data.numInstances() > 0 ? data.getRandomNumberGenerator(this.getSeed()) : new Random(this.getSeed());
            data = data.resampleWithWeights(r);
        }
        this.m_Classifier.buildClassifier(data);
    }

    /**
     * Returns the class distribution for an instance: the base classifier's
     * distribution when reweighting is used, otherwise that distribution passed
     * through {@link #convertDistribution}.
     *
     * @param instance the instance to classify
     * @return the class distribution
     * @throws Exception if the base classifier or the cost computation fails
     */
    @Override
    public double[] distributionForInstance(Instance instance) throws Exception {
        double[] baseDistribution = this.m_Classifier.distributionForInstance(instance);
        if (!this.m_MinimizeExpectedCost) {
            return baseDistribution;
        }
        return this.convertDistribution(baseDistribution, instance);
    }

    /**
     * Overwrites {@code pred} in place with 1.0 at the index of the minimum
     * expected cost and 0.0 elsewhere.
     *
     * @param pred the predicted class distribution; modified and returned
     * @param instance the instance the prediction belongs to
     * @return the same array as {@code pred}
     * @throws Exception if the expected costs cannot be computed
     */
    protected double[] convertDistribution(double[] pred, Instance instance) throws Exception {
        double[] costs = this.m_CostMatrix.expectedCosts(pred, instance);
        int classIndex = Utils.minIndex(costs);
        for (int i = 0; i < pred.length; ++i) {
            pred[i] = i == classIndex ? 1.0 : 0.0;
        }
        return pred;
    }

    /**
     * Returns class distributions for a batch of instances, delegating to the
     * base classifier's batch prediction when it is a {@link BatchPredictor}.
     *
     * @param insts the instances to classify
     * @return one distribution per instance
     * @throws Exception if prediction fails
     */
    @Override
    public double[][] distributionsForInstances(Instances insts) throws Exception {
        if (this.getClassifier() instanceof BatchPredictor) {
            double[][] dists = ((BatchPredictor) this.getClassifier()).distributionsForInstances(insts);
            if (this.m_MinimizeExpectedCost) {
                for (int i = 0; i < dists.length; ++i) {
                    dists[i] = this.convertDistribution(dists[i], insts.instance(i));
                }
            }
            return dists;
        }
        double[][] result = new double[insts.numInstances()][insts.numClasses()];
        for (int i = 0; i < insts.numInstances(); ++i) {
            result[i] = this.distributionForInstance(insts.instance(i));
        }
        return result;
    }

    /**
     * Returns the tip text for the batchSize property.
     *
     * @return the tip text
     */
    @Override
    public String batchSizeTipText() {
        return "Batch size to use if base learner is a BatchPredictor";
    }

    /**
     * Sets the batch size on the base classifier when it is a
     * {@link BatchPredictor}, otherwise on the superclass.
     *
     * @param size the batch size
     */
    @Override
    public void setBatchSize(String size) {
        if (this.getClassifier() instanceof BatchPredictor) {
            ((BatchPredictor) this.getClassifier()).setBatchSize(size);
        } else {
            super.setBatchSize(size);
        }
    }

    /**
     * Returns the batch size of the base classifier when it is a
     * {@link BatchPredictor}, otherwise that of the superclass.
     *
     * @return the batch size
     */
    @Override
    public String getBatchSize() {
        if (this.getClassifier() instanceof BatchPredictor) {
            return ((BatchPredictor) this.getClassifier()).getBatchSize();
        }
        return super.getBatchSize();
    }

    /**
     * Reports whether the base classifier implements more efficient batch prediction.
     *
     * @return false if the base classifier is not a {@link BatchPredictor},
     *         otherwise the base classifier's own answer
     */
    @Override
    public boolean implementsMoreEfficientBatchPrediction() {
        if (!(this.getClassifier() instanceof BatchPredictor)) {
            return false;
        }
        return ((BatchPredictor) this.getClassifier()).implementsMoreEfficientBatchPrediction();
    }

    /**
     * Returns the graph type of the base classifier when it is {@link Drawable}.
     *
     * @return the base classifier's graph type, or 0 when it is not drawable
     *         (NOTE(review): 0 presumably matches Drawable.NOT_DRAWABLE; confirm)
     */
    @Override
    public int graphType() {
        if (this.m_Classifier instanceof Drawable) {
            return ((Drawable) this.m_Classifier).graphType();
        }
        return 0;
    }

    /**
     * Returns the graph produced by the base classifier.
     *
     * @return the graph description
     * @throws Exception if the base classifier is not {@link Drawable} or graphing fails
     */
    @Override
    public String graph() throws Exception {
        if (this.m_Classifier instanceof Drawable) {
            return ((Drawable) this.m_Classifier).graph();
        }
        throw new Exception("Classifier: " + this.getClassifierSpec() + " cannot be graphed");
    }

    /**
     * Returns a textual description of the strategy in use, the base classifier
     * and its model, and the cost matrix.
     *
     * @return the description
     */
    public String toString() {
        if (this.m_Classifier == null) {
            return "CostSensitiveClassifier: No model built yet.";
        }
        String result = "CostSensitiveClassifier using ";
        result = this.m_MinimizeExpectedCost ? result + "minimized expected misclasification cost\n" : result + "reweighted training instances\n";
        // The matrix can be null when an old-format cost file has not been read yet; avoid an NPE.
        String matrixText = this.m_CostMatrix != null ? this.m_CostMatrix.toString() : "No cost matrix loaded yet.";
        result = result + "\n" + this.getClassifierSpec() + "\n\nClassifier Model\n" + this.m_Classifier.toString() + "\n\nCost Matrix\n" + matrixText;
        return result;
    }

    /**
     * Returns the revision string.
     *
     * @return the revision
     */
    @Override
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 15478 $");
    }

    /**
     * Runs this classifier from the command line.
     *
     * @param argv the command-line options
     */
    public static void main(String[] argv) {
        CostSensitiveClassifier.runClassifier(new CostSensitiveClassifier(), argv);
    }
}

