/*
 * Decompiled with CFR 0.152.
 */
package pl.allegro.tech.hermes.consumers.supervisor.workload.weighted;

import java.time.Clock;
import java.time.Duration;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import pl.allegro.tech.hermes.consumers.supervisor.workload.weighted.ConsumerNode;
import pl.allegro.tech.hermes.consumers.supervisor.workload.weighted.ConsumerNodeLoad;
import pl.allegro.tech.hermes.consumers.supervisor.workload.weighted.ExponentiallyWeightedMovingAverage;
import pl.allegro.tech.hermes.consumers.supervisor.workload.weighted.TargetWeightCalculator;
import pl.allegro.tech.hermes.consumers.supervisor.workload.weighted.Weight;
import pl.allegro.tech.hermes.consumers.supervisor.workload.weighted.WeightedWorkloadMetricsReporter;

/**
 * {@link TargetWeightCalculator} that redistributes the consumers' current total weight
 * according to a per-consumer score.
 *
 * <p>For every consumer whose load is defined, the current score is that consumer's share of the
 * summed operations per second. It is then corrected by {@code scoringGain} times the difference
 * between the average CPU utilization (over consumers with a defined load) and the consumer's own
 * CPU utilization. The result is fed into a per-consumer {@link ExponentiallyWeightedMovingAverage}
 * and clamped to [0.01, 1.0]. The total current weight is then split proportionally to these new
 * scores. Consumers without a defined load receive the average weight instead.
 *
 * <p>Per-consumer moving-average state is kept in a plain {@link HashMap}; this class performs no
 * synchronization of its own.
 */
public class ScoringTargetWeightCalculator
implements TargetWeightCalculator {
    /** Lower bound of a proposed score; keeps every scored consumer's weight share above zero. */
    private static final double MIN_SCORE = 0.01;
    /** Upper bound of a proposed score. */
    private static final double MAX_SCORE = 1.0;
    private final WeightedWorkloadMetricsReporter metrics;
    private final Clock clock;
    private final Duration scoringWindowSize;
    private final double scoringGain;
    /** Smoothed score per consumer id; entries are dropped when a consumer disappears. */
    private final Map<String, ExponentiallyWeightedMovingAverage> scores = new HashMap<>();

    /**
     * @param metrics           receives current/proposed weights, scores and scoring errors
     * @param clock             source of the timestamps passed to the moving averages
     * @param scoringWindowSize window passed to each newly created {@link ExponentiallyWeightedMovingAverage}
     * @param scoringGain       factor applied to the CPU-utilization error when correcting a score
     */
    public ScoringTargetWeightCalculator(WeightedWorkloadMetricsReporter metrics, Clock clock, Duration scoringWindowSize, double scoringGain) {
        this.metrics = metrics;
        this.clock = clock;
        this.scoringWindowSize = scoringWindowSize;
        this.scoringGain = scoringGain;
    }

    /**
     * Proposes new weights for the given consumers.
     *
     * <p>Side effects: discards moving-average state for consumer ids not present in
     * {@code consumers}, updates the moving average of every consumer with a defined load, and
     * reports current weights, per-consumer scores/errors and the proposed weights to
     * {@code metrics}.
     *
     * @param consumers consumer nodes to balance (consumer ids are expected to be unique;
     *                  {@link Collectors#toMap} rejects duplicates)
     * @return proposed weight for every consumer id in {@code consumers}
     */
    @Override
    public Map<String, Weight> calculate(Collection<ConsumerNode> consumers) {
        removeScoresForInactiveConsumers(consumers);
        metrics.reportCurrentWeights(consumers);
        Map<String, ConsumerNodeLoad> loadPerConsumer = mapConsumerIdToLoad(consumers);
        double targetCpuUtilization = calculateTargetCpuUtilization(loadPerConsumer);
        Map<String, Double> currentScores = calculateCurrentScores(loadPerConsumer);
        Map<String, Double> newScores = new HashMap<>();
        for (Map.Entry<String, Double> entry : currentScores.entrySet()) {
            String consumerId = entry.getKey();
            double cpuUtilization = loadPerConsumer.get(consumerId).getCpuUtilization();
            // Positive when this consumer is below the average CPU utilization, so its score grows.
            double error = targetCpuUtilization - cpuUtilization;
            double currentScore = entry.getValue();
            double newScore = calculateNewScore(consumerId, currentScore, error);
            newScores.put(consumerId, newScore);
            metrics.reportCurrentScore(consumerId, currentScore);
            metrics.reportProposedScore(consumerId, newScore);
            metrics.reportScoringError(consumerId, error);
        }
        Map<String, Weight> newWeights = calculateWeights(consumers, newScores);
        metrics.reportProposedWeights(newWeights);
        return newWeights;
    }

    /** Drops moving-average state for every consumer id that is not in {@code consumers}. */
    private void removeScoresForInactiveConsumers(Collection<ConsumerNode> consumers) {
        Set<String> consumerIds = consumers.stream()
                .map(ConsumerNode::getConsumerId)
                .collect(Collectors.toSet());
        scores.keySet().retainAll(consumerIds);
    }

    /** Maps each consumer id to its initial load; throws on duplicate consumer ids. */
    private Map<String, ConsumerNodeLoad> mapConsumerIdToLoad(Collection<ConsumerNode> consumers) {
        return consumers.stream()
                .collect(Collectors.toMap(ConsumerNode::getConsumerId, ConsumerNode::getInitialLoad));
    }

    /** Average CPU utilization over consumers with a defined load, or 0.0 if there are none. */
    private double calculateTargetCpuUtilization(Map<String, ConsumerNodeLoad> loadPerConsumer) {
        return loadPerConsumer.values().stream()
                .filter(ConsumerNodeLoad::isDefined)
                .mapToDouble(ConsumerNodeLoad::getCpuUtilization)
                .average()
                .orElse(0.0);
    }

    /**
     * Computes each defined-load consumer's share of the summed operations per second.
     * Consumers with an undefined load get no entry.
     */
    private Map<String, Double> calculateCurrentScores(Map<String, ConsumerNodeLoad> loadPerConsumer) {
        Map<String, Double> opsPerConsumer = loadPerConsumer.entrySet().stream()
                .filter(e -> e.getValue().isDefined())
                .collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().sumOperationsPerSecond()));
        double opsSum = opsPerConsumer.values().stream().mapToDouble(Double::doubleValue).sum();
        return opsPerConsumer.entrySet().stream()
                .collect(Collectors.toMap(Map.Entry::getKey, e -> calculateCurrentScore(e.getValue(), opsSum)));
    }

    /** Returns {@code ops / opsSum}, or 0.0 when {@code opsSum} is not positive. */
    private double calculateCurrentScore(double ops, double opsSum) {
        if (opsSum > 0.0) {
            return ops / opsSum;
        }
        return 0.0;
    }

    /**
     * Applies a proportional correction to {@code currentScore} and feeds the result into this
     * consumer's moving average (created on first use with {@code scoringWindowSize}), sampled at
     * the clock's current instant. Returns the smoothed value, clamped to the allowed range.
     */
    private double calculateNewScore(String consumerId, double currentScore, double error) {
        double rawScore = currentScore + scoringGain * error;
        ExponentiallyWeightedMovingAverage average = scores.computeIfAbsent(consumerId, ignored -> new ExponentiallyWeightedMovingAverage(scoringWindowSize));
        double avg = average.update(rawScore, clock.instant());
        return ensureScoreRanges(avg);
    }

    /** Clamps {@code score} into [{@code MIN_SCORE}, {@code MAX_SCORE}]. */
    private double ensureScoreRanges(double score) {
        return Math.max(Math.min(score, MAX_SCORE), MIN_SCORE);
    }

    /**
     * Splits the summed current weight of all consumers.
     *
     * <p>Consumers without a new score (i.e. with an undefined load) get the average weight. The
     * remainder is divided among the scored consumers proportionally to their new scores.
     */
    private Map<String, Weight> calculateWeights(Collection<ConsumerNode> consumers, Map<String, Double> newScores) {
        Weight sum = consumers.stream().map(ConsumerNode::getWeight).reduce(Weight.ZERO, Weight::add);
        Weight avgWeight = calculateAvgWeight(sum, consumers.size());
        List<ConsumerNode> consumersWithoutScore = consumers.stream()
                .filter(consumerNode -> !newScores.containsKey(consumerNode.getConsumerId()))
                .collect(Collectors.toList());
        Map<String, Weight> newWeights = new HashMap<>();
        for (ConsumerNode consumer : consumersWithoutScore) {
            newWeights.put(consumer.getConsumerId(), avgWeight);
            sum = sum.subtract(avgWeight);
        }
        // Every new score is >= MIN_SCORE, so this sum is positive whenever the loop below runs.
        double newScoresSum = newScores.values().stream().mapToDouble(Double::doubleValue).sum();
        for (Map.Entry<String, Double> entry : newScores.entrySet()) {
            Weight weight = sum.multiply(entry.getValue() / newScoresSum);
            newWeights.put(entry.getKey(), weight);
        }
        return newWeights;
    }

    /** Returns {@code sum / consumerCount}, or {@link Weight#ZERO} when there are no consumers. */
    private Weight calculateAvgWeight(Weight sum, int consumerCount) {
        if (consumerCount == 0) {
            return Weight.ZERO;
        }
        return sum.divide(consumerCount);
    }
}

