/*
 * Decompiled with CFR 0.152.
 */
package weka.filters.supervised.attribute;

import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.ObjectInputStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Enumeration;
import java.util.Vector;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.misc.InputMappedClassifier;
import weka.classifiers.rules.ZeroR;
import weka.core.AbstractInstance;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.SerializationHelper;
import weka.core.SparseInstance;
import weka.core.Utils;
import weka.core.WeightedAttributesHandler;
import weka.core.WeightedInstancesHandler;
import weka.core.WekaException;
import weka.filters.SimpleBatchFilter;

public class AddClassification
extends SimpleBatchFilter
implements WeightedAttributesHandler,
WeightedInstancesHandler {
    private static final long serialVersionUID = -1931467132568441909L;
    protected Classifier m_Classifier = new ZeroR();
    protected File m_SerializedClassifierFile = new File(System.getProperty("user.dir"));
    protected Classifier m_ActualClassifier = null;
    protected Instances m_SerializedHeader = null;
    protected boolean m_OutputClassification = false;
    protected boolean m_RemoveOldClass = false;
    protected boolean m_OutputDistribution = false;
    protected boolean m_OutputErrorFlag = false;

    @Override
    public String globalInfo() {
        return "A filter for adding the classification, the class distribution and an error flag to a dataset with a classifier. The classifier is either trained on the data itself or provided as serialized model.";
    }

    @Override
    public Enumeration<Option> listOptions() {
        Vector<Option> result = new Vector<Option>();
        result.addElement(new Option("\tFull class name of classifier to use, followed\n\tby scheme options. eg:\n\t\t\"weka.classifiers.bayes.NaiveBayes -D\"\n\t(default: weka.classifiers.rules.ZeroR)", "W", 1, "-W <classifier specification>"));
        result.addElement(new Option("\tInstead of training a classifier on the data, one can also provide\n\ta serialized model and use that for tagging the data.", "serialized", 1, "-serialized <file>"));
        result.addElement(new Option("\tAdds an attribute with the actual classification.\n\t(default: off)", "classification", 0, "-classification"));
        result.addElement(new Option("\tRemoves the old class attribute.\n\t(default: off)", "remove-old-class", 0, "-remove-old-class"));
        result.addElement(new Option("\tAdds attributes with the distribution for all classes \n\t(for numeric classes this will be identical to the attribute \n\toutput with '-classification').\n\t(default: off)", "distribution", 0, "-distribution"));
        result.addElement(new Option("\tAdds an attribute indicating whether the classifier output \n\ta wrong classification (for numeric classes this is the numeric \n\tdifference).\n\t(default: off)", "error", 0, "-error"));
        result.addAll(Collections.list(super.listOptions()));
        return result.elements();
    }

    @Override
    public void setOptions(String[] options) throws Exception {
        this.setOutputClassification(Utils.getFlag("classification", options));
        this.setRemoveOldClass(Utils.getFlag("remove-old-class", options));
        this.setOutputDistribution(Utils.getFlag("distribution", options));
        this.setOutputErrorFlag(Utils.getFlag("error", options));
        boolean serializedModel = false;
        String tmpStr = Utils.getOption("serialized", options);
        if (tmpStr.length() != 0) {
            File file = new File(tmpStr);
            if (!file.exists()) {
                throw new FileNotFoundException("File '" + file.getAbsolutePath() + "' not found!");
            }
            if (file.isDirectory()) {
                throw new FileNotFoundException("'" + file.getAbsolutePath() + "' points to a directory not a file!");
            }
            this.setSerializedClassifierFile(file);
            serializedModel = true;
        } else {
            this.setSerializedClassifierFile(null);
        }
        if (!serializedModel) {
            String[] tmpOptions;
            tmpStr = Utils.getOption('W', options);
            if (tmpStr.length() == 0) {
                tmpStr = ZeroR.class.getName();
            }
            if ((tmpOptions = Utils.splitOptions(tmpStr)).length == 0) {
                throw new Exception("Invalid classifier specification string");
            }
            tmpStr = tmpOptions[0];
            tmpOptions[0] = "";
            this.setClassifier(AbstractClassifier.forName(tmpStr, tmpOptions));
        }
        super.setOptions(options);
        Utils.checkForRemainingOptions(options);
    }

    @Override
    public String[] getOptions() {
        File file;
        Vector<String> result = new Vector<String>();
        if (this.getOutputClassification()) {
            result.add("-classification");
        }
        if (this.getRemoveOldClass()) {
            result.add("-remove-old-class");
        }
        if (this.getOutputDistribution()) {
            result.add("-distribution");
        }
        if (this.getOutputErrorFlag()) {
            result.add("-error");
        }
        if ((file = this.getSerializedClassifierFile()) != null && !file.isDirectory()) {
            result.add("-serialized");
            result.add(file.getAbsolutePath());
        } else {
            result.add("-W");
            result.add(this.getClassifierSpec());
        }
        Collections.addAll(result, super.getOptions());
        return result.toArray(new String[result.size()]);
    }

    @Override
    protected void reset() {
        super.reset();
        this.m_ActualClassifier = null;
        this.m_SerializedHeader = null;
    }

    protected Classifier getActualClassifier() {
        block6: {
            if (this.m_ActualClassifier == null) {
                try {
                    File file = this.getSerializedClassifierFile();
                    if (!file.isDirectory()) {
                        ObjectInputStream ois = SerializationHelper.getObjectInputStream(new FileInputStream(file));
                        this.m_ActualClassifier = (Classifier)ois.readObject();
                        this.m_SerializedHeader = null;
                        try {
                            this.m_SerializedHeader = (Instances)ois.readObject();
                        }
                        catch (Exception e2) {
                            this.m_SerializedHeader = null;
                        }
                        ois.close();
                        break block6;
                    }
                    this.m_ActualClassifier = AbstractClassifier.makeCopy(this.m_Classifier);
                }
                catch (Exception e3) {
                    this.m_ActualClassifier = null;
                    System.err.println("Failed to instantiate classifier:");
                    e3.printStackTrace();
                }
            }
        }
        return this.m_ActualClassifier;
    }

    @Override
    protected void testInputFormat(Instances instanceInfo) throws Exception {
        Classifier classifier = this.getActualClassifier();
        if (classifier instanceof InputMappedClassifier) {
            Instances trainingData = ((InputMappedClassifier)classifier).getModelHeader(new Instances(instanceInfo, 0));
            this.getCapabilities(trainingData).testWithFail(trainingData);
        } else {
            this.getCapabilities(instanceInfo).testWithFail(instanceInfo);
        }
    }

    @Override
    public Capabilities getCapabilities() {
        Capabilities result;
        if (this.getActualClassifier() == null) {
            result = super.getCapabilities();
            result.disableAll();
        } else {
            result = this.getActualClassifier().getCapabilities();
        }
        result.setMinimumNumberInstances(0);
        return result;
    }

    public String classifierTipText() {
        return "The classifier to use for classification.";
    }

    public void setClassifier(Classifier value) {
        this.m_Classifier = value;
    }

    public Classifier getClassifier() {
        return this.m_Classifier;
    }

    protected String getClassifierSpec() {
        Classifier c2 = this.getClassifier();
        String result = c2.getClass().getName();
        if (c2 instanceof OptionHandler) {
            result = result + " " + Utils.joinOptions(((OptionHandler)((Object)c2)).getOptions());
        }
        return result;
    }

    public String serializedClassifierFileTipText() {
        return "A file containing the serialized model of a trained classifier.";
    }

    public File getSerializedClassifierFile() {
        return this.m_SerializedClassifierFile;
    }

    public void setSerializedClassifierFile(File value) {
        if (value == null || !value.exists()) {
            value = new File(System.getProperty("user.dir"));
        }
        this.m_SerializedClassifierFile = value;
    }

    public String outputClassificationTipText() {
        return "Whether to add an attribute with the actual classification.";
    }

    public boolean getOutputClassification() {
        return this.m_OutputClassification;
    }

    public void setOutputClassification(boolean value) {
        this.m_OutputClassification = value;
    }

    public String removeOldClassTipText() {
        return "Whether to remove the old class attribute.";
    }

    public boolean getRemoveOldClass() {
        return this.m_RemoveOldClass;
    }

    public void setRemoveOldClass(boolean value) {
        this.m_RemoveOldClass = value;
    }

    public String outputDistributionTipText() {
        return "Whether to add attributes with the distribution for all classes (for numeric classes this will be identical to the attribute output with 'outputClassification').";
    }

    public boolean getOutputDistribution() {
        return this.m_OutputDistribution;
    }

    public void setOutputDistribution(boolean value) {
        this.m_OutputDistribution = value;
    }

    public String outputErrorFlagTipText() {
        return "Whether to add an attribute indicating whether the classifier output a wrong classification (for numeric classes this is the numeric difference).";
    }

    public boolean getOutputErrorFlag() {
        return this.m_OutputErrorFlag;
    }

    public void setOutputErrorFlag(boolean value) {
        this.m_OutputErrorFlag = value;
    }

    @Override
    protected Instances determineOutputFormat(Instances inputFormat) throws Exception {
        int i;
        int classindex = -1;
        Attribute classAttribute = inputFormat.classIndex() >= 0 ? inputFormat.classAttribute() : null;
        Classifier classifier = this.getActualClassifier();
        if (!this.getSerializedClassifierFile().isDirectory()) {
            if (classifier instanceof InputMappedClassifier) {
                classAttribute = ((InputMappedClassifier)classifier).getModelHeader(new Instances(inputFormat, 0)).classAttribute();
            }
        } else if (classAttribute == null && !(classifier instanceof InputMappedClassifier)) {
            throw new IllegalArgumentException("AddClassification: class must be set if InputMappedClassifier is not used.");
        }
        ArrayList<Attribute> atts = new ArrayList<Attribute>();
        for (i = 0; i < inputFormat.numAttributes(); ++i) {
            if (i == inputFormat.classIndex() && this.getRemoveOldClass()) continue;
            if (i == inputFormat.classIndex()) {
                classindex = i;
            }
            atts.add((Attribute)inputFormat.attribute(i).copy());
        }
        if (this.getOutputClassification()) {
            if (classindex == -1) {
                classindex = atts.size();
            }
            atts.add(classAttribute.copy("classification"));
        }
        if (this.getOutputDistribution()) {
            if (classAttribute.isNominal()) {
                for (i = 0; i < classAttribute.numValues(); ++i) {
                    atts.add(new Attribute("distribution_" + classAttribute.value(i)));
                }
            } else {
                atts.add(new Attribute("distribution"));
            }
        }
        if (this.getOutputErrorFlag()) {
            if (classAttribute.isNominal()) {
                ArrayList<String> values = new ArrayList<String>();
                values.add("no");
                values.add("yes");
                atts.add(new Attribute("error", values));
            } else {
                atts.add(new Attribute("error"));
            }
        }
        Instances result = new Instances(inputFormat.relationName(), atts, 0);
        result.setClassIndex(classindex);
        return result;
    }

    @Override
    protected Instances process(Instances instances) throws Exception {
        if (!this.isFirstBatchDone()) {
            this.getActualClassifier();
            if (!this.getSerializedClassifierFile().isDirectory()) {
                if (this.m_SerializedHeader != null && !this.m_SerializedHeader.equalHeaders(instances) && !(this.m_ActualClassifier instanceof InputMappedClassifier)) {
                    throw new WekaException("Training header of classifier and filter dataset don't match:\n" + this.m_SerializedHeader.equalHeadersMsg(instances));
                }
            } else {
                this.m_ActualClassifier.buildClassifier(instances);
            }
        }
        Instances result = this.getOutputFormat();
        for (int i = 0; i < instances.numInstances(); ++i) {
            Instance oldInstance = instances.instance(i);
            double[] oldValues = oldInstance.toDoubleArray();
            double[] newValues = new double[result.numAttributes()];
            int start = 0;
            for (int j = 0; j < oldValues.length; ++j) {
                if (j == this.inputFormatPeek().classIndex() && this.getRemoveOldClass()) continue;
                newValues[start++] = oldValues[j];
            }
            if (this.getOutputClassification()) {
                newValues[start] = this.m_ActualClassifier.classifyInstance(oldInstance);
                ++start;
            }
            if (this.getOutputDistribution()) {
                double[] distribution = this.m_ActualClassifier.distributionForInstance(oldInstance);
                for (int n = 0; n < distribution.length; ++n) {
                    newValues[start] = distribution[n];
                    ++start;
                }
            }
            if (this.getOutputErrorFlag()) {
                Instance inst = oldInstance;
                if (this.m_ActualClassifier instanceof InputMappedClassifier) {
                    inst = ((InputMappedClassifier)this.m_ActualClassifier).constructMappedInstance(inst);
                }
                newValues[start] = instances.classIndex() < 0 ? Utils.missingValue() : (result.classAttribute().isNominal() ? (inst.classValue() == this.m_ActualClassifier.classifyInstance(oldInstance) ? 0.0 : 1.0) : this.m_ActualClassifier.classifyInstance(oldInstance) - inst.classValue());
                ++start;
            }
            AbstractInstance newInstance = oldInstance instanceof SparseInstance ? new SparseInstance(oldInstance.weight(), newValues) : new DenseInstance(oldInstance.weight(), newValues);
            this.copyValues(newInstance, false, oldInstance.dataset(), this.outputFormatPeek());
            result.add(newInstance);
        }
        return result;
    }

    @Override
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 14534 $");
    }

    public static void main(String[] args) {
        AddClassification.runFilter(new AddClassification(), args);
    }
}

