package ch.hevs.medgift.epad.plugins.service;

import java.util.ArrayList;
import java.util.List;

import javax.ejb.EJB;
import javax.ejb.Stateless;

import ch.hevs.medgift.plugins.common.models.EpadClassificationResponse;
import ch.hevs.medgift.plugins.common.models.QuantitativeImagingFeatures;
import weka.classifiers.Classifier;
import weka.classifiers.functions.LibSVM;
import weka.core.Attribute;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.SparseInstance;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Normalize;
import com.google.gson.Gson;

/**
 * Stateless bean used to train SVM models and classify ROIs.
 * <p>
 * Models are trained with Weka's {@link LibSVM} wrapper on quantitative imaging
 * features and serialized under {@link #MODELS_PATH}. Classification reloads a
 * serialized model and reports one probability estimate per label.
 *
 * @author Himmouche Abderrahmane
 */
@Stateless
public class SVMClassifierBean implements SVMClassifierBeanLocal {

    /** Directory in which serialized models are written and from which they are read. */
    private static final String MODELS_PATH = "/data/quantimage/models/";

    @EJB
    PersistenceBeanLocal persist;

    public SVMClassifierBean() {
    }

    /**
     * Loads the analysis features stored for a project.
     *
     * @param project_id identifier of the project whose analysis results are loaded
     * @return the list returned by {@link PersistenceBeanLocal#getAnalysisResults}, unchanged
     */
    @Override
    public List<QuantitativeImagingFeatures> extractData(String project_id) {
        // Load all data belonging to the same project ID
        return persist.getAnalysisResults(project_id);
    }

    /**
     * Trains an SVM model from a set of regions and saves it to disk.
     *
     * @param modelName    prefix of the saved model file name
     * @param featuresList training regions; the label of each region is its image
     *                     annotation name, the feature type of the first region is
     *                     used in the file name
     * @return the saved model file name (see {@link #saveModel}), or {@code null} if
     *         the input is null/empty, has fewer than two distinct labels, or if
     *         training or saving fails
     */
    @Override
    public String trainModel(String modelName, List<QuantitativeImagingFeatures> featuresList) {
        if (featuresList == null || featuresList.isEmpty()) {
            return null;
        }

        Instances dataset = this.getInstancesFromListOfInstance(featuresList);
        if (dataset == null) {
            // Fewer than two distinct labels: nothing to discriminate.
            return null;
        }

        // The label attribute is always built at index 0.
        dataset.setClassIndex(0);

        LibSVM svm = new LibSVM();
        svm.setNormalize(true);
        svm.setCost(1000);
        svm.setGamma(10);
        // Probability estimates are required by classify(), which reports per-label probabilities.
        svm.setProbabilityEstimates(true);
        try {
            svm.buildClassifier(dataset);
        } catch (Exception e) {
            // Do not save a classifier whose training failed.
            e.printStackTrace();
            return null;
        }

        try {
            return this.saveModel(svm, modelName, featuresList.get(0).getFeatureType());
        } catch (Exception e) {
            e.printStackTrace();
            return null;
        }
    }

    /**
     * Creates (or overwrites) the model for a project, using the project id as model name.
     *
     * @param project_id identifier of the project; used both to load features and as model name
     */
    @Override
    public void createUpdateModel(String project_id) {
        List<QuantitativeImagingFeatures> results = this.extractData(project_id);
        trainModel(project_id, results);
    }

    /**
     * Converts a list of region features into a Weka dataset.
     * <p>
     * Attribute 0 is a nominal "Label" attribute whose values are the distinct
     * image annotation names in first-seen order; attributes 1..N are numeric
     * "Feature1".."FeatureN". N is taken from the first element, so every element
     * is assumed to have at least that many feature values.
     *
     * @param list region features to convert
     * @return the dataset, or {@code null} if {@code list} is null/empty or contains
     *         fewer than two distinct labels
     */
    public Instances getInstancesFromListOfInstance(
            List<QuantitativeImagingFeatures> list) {
        if (list == null || list.isEmpty()) {
            return null;
        }

        // Distinct labels in first-seen order; this order defines the class value indices.
        ArrayList<String> labels = new ArrayList<String>();
        for (QuantitativeImagingFeatures item : list) {
            String name = item.getImageAnnotation().getName();
            if (!labels.contains(name)) {
                labels.add(name);
            }
        }
        // A classifier needs at least two different labels.
        if (labels.size() < 2) {
            return null;
        }

        int numDimensions = list.get(0).getFeatureValues().length;

        // Header: nominal label first, then one numeric attribute per feature.
        ArrayList<Attribute> attributes = new ArrayList<Attribute>(numDimensions + 1);
        attributes.add(new Attribute("Label", labels, 0));
        for (int dim = 1; dim <= numDimensions; dim++) {
            attributes.add(new Attribute("Feature" + dim, dim));
        }

        // One instance per region.
        Instances newDataset = new Instances("Dataset", attributes, list.size());
        for (QuantitativeImagingFeatures item : list) {
            SparseInstance inst = new SparseInstance(numDimensions + 1);
            inst.setValue(attributes.get(0), item.getImageAnnotation().getName());
            for (int dim = 1; dim <= numDimensions; dim++) {
                inst.setValue(attributes.get(dim), item.getFeatureValues()[dim - 1]);
            }
            newDataset.add(inst);
        }

        System.out.println(newDataset.toSummaryString());

        return newDataset;
    }

    /**
     * Serializes a model to {@link #MODELS_PATH}.
     * <p>
     * NOTE(review): {@code modelName} and {@code featureType} are concatenated into
     * a file path unchecked; if either can come from an untrusted source, validate
     * it against path traversal.
     *
     * @param classifier  the trained classifier; must be a {@link LibSVM} instance
     * @param modelName   model name prefix
     * @param featureType feature type suffix
     * @return the file name written, {@code modelName + "_" + featureType + ".model"}
     * @throws Exception if the classifier is not a LibSVM or serialization fails
     */
    public String saveModel(Classifier classifier, String modelName, String featureType) throws Exception {
        LibSVM libSVM = (LibSVM) classifier;
        System.out.println("gamma : " + libSVM.getGamma());
        String modelFileName = modelName + "_" + featureType + ".model";
        weka.core.SerializationHelper.write(MODELS_PATH + modelFileName, libSVM);

        return modelFileName;
    }

    /**
     * Classifies one ROI with a previously saved model.
     *
     * @param features      feature values of the ROI to classify
     * @param modelFileName model name; ".model" is appended when reading the file
     * @return one probability estimate per known label, or {@code null} on any failure
     */
    @Override
    public EpadClassificationResponse classify(double[] features, String modelFileName) {
        EpadClassificationResponse result = new EpadClassificationResponse();
        try {
            // NOTE(review): saveModel() returns a name that already ends in ".model" and this
            // method appends ".model" again; callers presumably strip the extension — confirm.
            LibSVM svm = (LibSVM) weka.core.SerializationHelper.read(MODELS_PATH + modelFileName + ".model");

            // Rebuild the distinct labels (first-seen order) from the data stored for this model.
            // NOTE(review): distribution[i] is attributed to predClasses[i]; this is only correct if
            // this order matches the label order used in trainModel — verify persistence ordering.
            ArrayList<String> predClasses = new ArrayList<String>();
            List<QuantitativeImagingFeatures> list = persist.getAnalysisResultsForModel(modelFileName);
            for (QuantitativeImagingFeatures item : list) {
                String name = item.getImageAnnotation().getName();
                if (!predClasses.contains(name)) {
                    predClasses.add(name);
                }
            }

            // Same header layout as training: label first, then "Feature1".."FeatureN".
            ArrayList<Attribute> attributes = new ArrayList<Attribute>(features.length + 1);
            attributes.add(new Attribute("Label", predClasses, 0));
            for (int i = 0; i < features.length; i++) {
                attributes.add(new Attribute("Feature" + (i + 1), i + 1));
            }

            // Dataset holding a single instance to classify.
            Instances instances = new Instances("DataSet", attributes, 1);
            instances.setClassIndex(0);

            SparseInstance ins = new SparseInstance(features.length + 1);
            for (int i = 0; i < features.length; i++) {
                ins.setValue(attributes.get(i + 1), features[i]);
            }
            instances.add(ins);
            Instance newInstance = instances.get(0);

            double preValue = svm.classifyInstance(newInstance);
            System.out.println(preValue);

            // Compute the distribution once instead of once per label.
            double[] distribution = svm.distributionForInstance(newInstance);
            for (int i = 0; i < predClasses.size(); i++) {
                result.addClassification(predClasses.get(i), new double[][]{{distribution[i]}});
            }
        } catch (Exception e) {
            e.printStackTrace();
            return null;
        }
        return result;
    }
}
