/*
 * Decompiled with CFR 0.152.
 */
package org.imixs.ml.api;

import java.io.UnsupportedEncodingException;
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.inject.Inject;
import javax.inject.Named;
import javax.ws.rs.Consumes;
import javax.ws.rs.POST;
import javax.ws.rs.Path;
import javax.ws.rs.Produces;
import javax.ws.rs.core.Response;
import org.imixs.melman.RestAPIException;
import org.imixs.melman.WorkflowClient;
import org.imixs.ml.api.TrainingApplication;
import org.imixs.ml.core.MLTrainingResult;
import org.imixs.ml.service.TrainingService;
import org.imixs.workflow.ItemCollection;
import org.imixs.workflow.util.JSONParser;
import org.imixs.workflow.xml.XMLDataCollectionAdapter;
import org.imixs.workflow.xml.XMLDocument;
import org.imixs.workflow.xml.XMLDocumentAdapter;

@Named
@Path(value="training")
@Produces(value={"application/xml", "application/json"})
public class TrainingResource {
    @Inject
    TrainingService trainingService;
    private static Logger logger = Logger.getLogger(TrainingResource.class.getName());

    @POST
    @Consumes(value={"application/xml", "application/json"})
    public Response trainData(XMLDocument xmlConfig) {
        Object mainLog = "";
        HashMap<String, ItemCollection> trainingDataSet = null;
        ItemCollection config = XMLDocumentAdapter.putDocument((XMLDocument)xmlConfig);
        logger.info("...starting training....");
        logger.fine("......model=" + config.getItemValueString("ml.training.model"));
        logger.info("......required qualityLevel=" + config.getItemValueString("ml.training.quality"));
        logger.info("......ocrMode=" + config.getItemValueString("tika.ocrmode"));
        try {
            WorkflowClient worklowClient = TrainingApplication.buildWorkflowClient((ItemCollection)config);
            List itemNames = config.getItemValue("workflow.entities");
            itemNames.add("ml.definitions");
            if (itemNames.contains("$file") || itemNames.contains("$snapshotid")) {
                logger.severe("$file and $snapshot must not be included in the workflow.entities!");
                System.exit(0);
            }
            String encodedQuery = URLEncoder.encode(config.getItemValueString("workflow.query"), StandardCharsets.UTF_8.toString());
            Object queryURL = "documents/search/" + encodedQuery + "?sortBy=$modified&sortReverse=true";
            queryURL = (String)queryURL + "&pageSize=" + config.getItemValueInteger("workflow.pagesize") + "&pageIndex=" + config.getItemValueInteger("workflow.pageindex");
            queryURL = TrainingApplication.appendItenNames((String)queryURL, (List)itemNames);
            logger.info("......select workitems: " + (String)queryURL);
            List documents = worklowClient.getCustomResource((String)queryURL);
            logger.info("...... " + documents.size() + " documents found");
            trainingDataSet = new HashMap<String, ItemCollection>();
            for (ItemCollection doc : documents) {
                trainingDataSet.put(doc.getUniqueID(), doc);
            }
            int trainingInterations = config.getItemValueInteger("ml.training.iterations");
            double trainigDropoutRate = config.getItemValueDouble("ml.training.dropoutrate");
            if (trainigDropoutRate < 0.0 || trainigDropoutRate >= 1.0) {
                logger.warning("WRONG ml.training.dropoutrate : " + trainigDropoutRate + " should be between 0 and 1");
                trainigDropoutRate = 0.25;
            }
            logger.info("..... iterations=" + trainingInterations);
            logger.info("..... dropoutrate=" + trainigDropoutRate);
            for (int iteration = 0; iteration < trainingInterations; ++iteration) {
                logger.info("\n\n\n---- Starting " + iteration + ". Training Iteration ----\n\n");
                String log = this.startSingleTrainingIteration(trainingDataSet, trainigDropoutRate, worklowClient, config);
                mainLog = (String)mainLog + log;
                logger.info(log);
                logger.info("\\n\\n\\n---- Training Iteration Completed -----");
            }
        }
        catch (UnsupportedEncodingException | RestAPIException e) {
            logger.warning("Failed to query documents: " + e.getMessage());
            e.printStackTrace();
        }
        logger.info(" ");
        logger.info(" ");
        logger.info("****************Trainng Iteration Completed ***********************");
        logger.info("SUMMMARY:");
        logger.info((String)mainLog);
        ItemCollection stats = new ItemCollection();
        return Response.ok((Object)XMLDataCollectionAdapter.getDataCollection((ItemCollection)stats), (String)"application/xml").build();
    }

    private String startSingleTrainingIteration(Map<String, ItemCollection> trainingDataSet, double trainigDropoutRate, WorkflowClient worklowClient, ItemCollection config) {
        int countTotal = 0;
        int countQualityGood = 0;
        int countQualityLow = 0;
        int countQualityBad = 0;
        double nerFactor = -1.0;
        double allNerFactors = 0.0;
        ArrayList<String> iterraionUniqueIDs = new ArrayList<String>(trainingDataSet.keySet());
        Collections.shuffle(iterraionUniqueIDs);
        int size = iterraionUniqueIDs.size();
        int targetSize = (int)((double)size - (double)size * trainigDropoutRate);
        while (iterraionUniqueIDs.size() > targetSize) {
            iterraionUniqueIDs.remove(0);
        }
        int currentCount = 0;
        for (String uniqueid : iterraionUniqueIDs) {
            ItemCollection doc = trainingDataSet.get(uniqueid);
            ++currentCount;
            ++countTotal;
            logger.fine("...... train " + uniqueid + "...");
            MLTrainingResult trainingResult = this.trainingService.trainWorkitemData(config, doc, worklowClient);
            if (trainingResult != null) {
                String resultData;
                switch (trainingResult.getQualityLevel()) {
                    case 10: {
                        ++countQualityGood;
                        break;
                    }
                    case 4: {
                        ++countQualityLow;
                        break;
                    }
                    default: {
                        ++countQualityBad;
                    }
                }
                if ((resultData = trainingResult.getData()) == null || resultData.isEmpty()) continue;
                try {
                    String nerString = JSONParser.getKey((String)"ner", (String)resultData);
                    double newNerFactor = Double.parseDouble(nerString);
                    nerFactor = (allNerFactors += newNerFactor) / (double)currentCount;
                }
                catch (Exception e) {
                    logger.severe("failed to parse training result (ner)");
                }
                continue;
            }
            ++countQualityBad;
        }
        DecimalFormat df = new DecimalFormat("###.##");
        String log = "\n......documents trained in total = " + countTotal + "\n";
        log = log + "  ......     quality level GOOD = " + df.format((double)countQualityGood / (double)countTotal * 100.0) + "%  (" + countQualityGood + ")\n";
        log = log + "  ......      quality level LOW = " + df.format((double)countQualityLow / (double)countTotal * 100.0) + "%  (" + countQualityLow + ")\n";
        log = log + "  ......      quality level BAD = " + df.format((double)countQualityBad / (double)countTotal * 100.0) + "%  (" + countQualityBad + ")";
        log = log + "\n  ......            average NER = " + nerFactor;
        log = log + "\n";
        return log;
    }
}

