/*
 * Decompiled with CFR 0.152.
 */
package org.monarchinitiative.phenol.ontology.scoredist;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.TreeMap;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import java.util.function.IntConsumer;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.monarchinitiative.phenol.base.PhenolRuntimeException;
import org.monarchinitiative.phenol.ontology.data.Ontology;
import org.monarchinitiative.phenol.ontology.data.TermId;
import org.monarchinitiative.phenol.ontology.scoredist.ObjectScoreDistribution;
import org.monarchinitiative.phenol.ontology.scoredist.ScoreDistribution;
import org.monarchinitiative.phenol.ontology.scoredist.ScoreSamplingOptions;
import org.monarchinitiative.phenol.ontology.similarity.Similarity;
import org.monarchinitiative.phenol.utils.MersenneTwister;
import org.monarchinitiative.phenol.utils.ProgressReporter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Empirically estimates, per world object, the distribution of similarity scores obtained when
 * randomly drawn query term sets are compared against that object's terms.
 *
 * <p>For each object, {@code options.getNumIterations()} random sets of non-obsolete ontology terms
 * are drawn and scored against the object's terms with the configured {@link Similarity}. The
 * scores (rounded to three decimals) are turned into a cumulative relative frequency table. Objects
 * are processed in parallel on {@code options.getNumThreads()} worker threads.
 */
public final class SimilarityScoreSampling {
    private static final Logger LOGGER = LoggerFactory.getLogger(SimilarityScoreSampling.class);
    /** Ontology whose non-obsolete terms form the pool that random query terms are drawn from. */
    private final Ontology ontology;
    /** Similarity measure used to score random query terms against an object's terms. */
    private final Similarity similarity;
    /** Copy (via {@code clone()}) of the options passed to the constructor. */
    private final ScoreSamplingOptions options;

    /**
     * Create a new sampler.
     *
     * @param ontology the ontology to draw random query terms from
     * @param similarity the similarity measure used for scoring
     * @param options sampling configuration; cloned, so this instance keeps its own copy
     */
    public SimilarityScoreSampling(Ontology ontology, Similarity similarity, ScoreSamplingOptions options) {
        this.ontology = ontology;
        this.similarity = similarity;
        this.options = (ScoreSamplingOptions) options.clone();
    }

    /**
     * Run sampling for every query size from {@code options.getMinNumTerms()} to
     * {@code options.getMaxNumTerms()}, both inclusive.
     *
     * @param labels map from world object ID to the terms of that object
     * @return map from query size (number of random terms) to the resulting score distribution
     */
    public Map<Integer, ScoreDistribution> performSampling(Map<Integer, ? extends Collection<TermId>> labels) {
        Map<Integer, ScoreDistribution> result = new HashMap<>();
        for (int numTerms = options.getMinNumTerms(); numTerms <= options.getMaxNumTerms(); ++numTerms) {
            result.put(numTerms, performSamplingForTermCount(labels, numTerms));
        }
        return result;
    }

    /**
     * Run sampling for a single query size over all selected objects, in parallel.
     *
     * <p>Only objects accepted by the min/max object ID window in the options are processed. An
     * exception while processing one object is logged, and that object is omitted from the result.
     *
     * @param labels map from world object ID to the terms of that object
     * @param numTerms number of random query terms drawn per sample
     * @return the score distribution for {@code numTerms}, keyed by object ID
     * @throws PhenolRuntimeException if the calling thread is interrupted while waiting for the
     *     workers; the thread's interrupt status is restored before throwing
     */
    public ScoreDistribution performSamplingForTermCount(Map<Integer, ? extends Collection<TermId>> labels, int numTerms) {
        LOGGER.info("Running precomputation for {} world objects using {} query terms...", labels.size(), numTerms);
        ProgressReporter progressReport = new ProgressReporter(LOGGER, "objects", labels.size());
        progressReport.start();

        // Written concurrently by the worker threads.
        ConcurrentHashMap<Integer, ObjectScoreDistribution> distributions = new ConcurrentHashMap<>();
        IntConsumer task = objectId -> {
            try {
                ObjectScoreDistribution dist = performComputation(objectId, labels.get(objectId), numTerms);
                distributions.put(dist.getObjectId(), dist);
                progressReport.incCurrent();
            } catch (Exception e) {
                LOGGER.error("An exception occured in parallel processing!", e);
            }
        };

        int numThreads = options.getNumThreads();
        ThreadPoolExecutor threadPoolExecutor = new ThreadPoolExecutor(
            numThreads, numThreads, 5L, TimeUnit.MICROSECONDS, new LinkedBlockingQueue<Runnable>());
        try {
            for (Integer objectId : labels.keySet()) {
                if (selectObject(objectId)) {
                    int id = objectId;
                    // execute() rather than submit(): the futures are never inspected, and submit()
                    // would silently capture any Error thrown by a task.
                    threadPoolExecutor.execute(() -> task.accept(id));
                }
            }
        } finally {
            // Always shut down, so the worker threads are released even if submission fails.
            threadPoolExecutor.shutdown();
        }
        try {
            threadPoolExecutor.awaitTermination(Long.MAX_VALUE, TimeUnit.NANOSECONDS);
        } catch (InterruptedException e) {
            threadPoolExecutor.shutdownNow();
            Thread.currentThread().interrupt();
            throw new PhenolRuntimeException("Could not wait for thread pool being done.", e);
        }
        progressReport.stop();
        LOGGER.info("Done running precomputation.");
        return new ScoreDistribution(numTerms, new HashMap<Integer, ObjectScoreDistribution>(distributions));
    }

    /**
     * Return whether {@code objectId} lies inside the optional inclusive window
     * [{@code options.getMinObjectId()}, {@code options.getMaxObjectId()}].
     * A {@code null} bound means that side is unbounded.
     */
    private boolean selectObject(Integer objectId) {
        if (options.getMinObjectId() != null && objectId < options.getMinObjectId()) {
            return false;
        }
        return options.getMaxObjectId() == null || objectId <= options.getMaxObjectId();
    }

    /**
     * Compute the sampled score distribution for a single object.
     *
     * @param objectId ID of the world object
     * @param terms the object's terms
     * @param numTerms number of random query terms drawn per sample
     * @return the object's score distribution
     */
    private ObjectScoreDistribution performComputation(int objectId, Collection<TermId> terms, int numTerms) {
        LOGGER.info("Running precomputation for world object {}.", objectId);
        // One generator per object, seeded from the configured seed plus the object ID, so an
        // object's random draws do not depend on which worker thread happens to process it.
        MersenneTwister rng = new MersenneTwister();
        rng.setSeed(options.getSeed() + objectId);
        int numIterations = options.getNumIterations();
        ObjectScoreDistribution result = new ObjectScoreDistribution(objectId, numTerms, numIterations,
            sampleScoreCumulativeRelFreq(terms, numTerms, numIterations, rng));
        LOGGER.info("Done computing precomputation for world object {}.", objectId);
        return result;
    }

    /**
     * Draw {@code numIterations} random term sets of size {@code numTerms}, score each against
     * {@code terms}, and build the cumulative relative frequency of the rounded scores.
     *
     * @param terms the object's terms
     * @param numTerms number of random query terms per sample
     * @param numIterations number of samples to draw
     * @param rng random source for the draws
     * @return map from score (rounded to three decimals) to the fraction of samples scoring at most
     *     that value; always contains the key {@code 0.0}
     */
    private TreeMap<Double, Double> sampleScoreCumulativeRelFreq(
            Collection<TermId> terms, int numTerms, int numIterations, Random rng) {
        List<TermId> allTermIds = new ArrayList<>(ontology.getNonObsoleteTermIds());
        // Sequential on purpose: every draw consumes the shared rng, in a fixed order.
        // Collect straight into a TreeMap so the result is mutable (the putIfAbsent below needs
        // this) and already sorted by score.
        TreeMap<Double, Long> counts = IntStream.range(0, numIterations)
            .mapToObj(i -> sampleRoundedScore(allTermIds, terms, numTerms, rng))
            .collect(Collectors.groupingBy(Function.identity(), TreeMap::new, Collectors.counting()));
        // Ensure a 0.0 entry always exists (presumably so lookups at score 0.0 are defined).
        counts.putIfAbsent(0.0, 0L);

        TreeMap<Double, Double> result = new TreeMap<>();
        long cumulative = 0L;
        for (Map.Entry<Double, Long> entry : counts.entrySet()) {
            cumulative += entry.getValue();
            result.put(entry.getKey(), (double) cumulative / numIterations);
        }
        return result;
    }

    /**
     * Score one random draw of {@code numTerms} terms against {@code terms}, rounded to three decimals.
     */
    private double sampleRoundedScore(List<TermId> allTermIds, Collection<TermId> terms, int numTerms, Random rng) {
        List<TermId> randomTerms = selectRandomElements(allTermIds, numTerms, rng);
        double score = similarity.computeScore(randomTerms, terms);
        return Math.round(score * 1000.0) / 1000.0;
    }

    /**
     * Select {@code count} elements of {@code src}, without repetition, using {@code rng}.
     *
     * <p>If {@code count} is at least {@code src.size()}, {@code src} itself is returned. Selection
     * uses rejection sampling, which is cheap while {@code count} is small compared to
     * {@code src.size()}. It assumes {@code src} holds no duplicate (equal) elements; otherwise it
     * may fail to terminate.
     *
     * @param src the list to draw from
     * @param count number of elements to select
     * @param rng random source; must be the caller's (possibly seeded) generator for reproducibility
     * @return the selected elements, in draw order
     */
    private static <E> List<E> selectRandomElements(List<E> src, int count, Random rng) {
        if (count >= src.size()) {
            return src;
        }
        List<E> selected = new ArrayList<>();
        int listSize = src.size();
        while (selected.size() < count) {
            E element = src.get(rng.nextInt(listSize));
            if (!selected.contains(element)) {
                selected.add(element);
            }
        }
        return selected;
    }
}

