/*
 * Decompiled with CFR 0.152.
 */
package tl.lin.data.cfd;

import com.google.common.collect.Maps;
import java.util.Map;
import tl.lin.data.cfd.Object2IntConditionalFrequencyDistribution;
import tl.lin.data.fd.Object2IntFrequencyDistribution;
import tl.lin.data.fd.Object2IntFrequencyDistributionEntry;
import tl.lin.data.map.HMapKL;
import tl.lin.data.pair.PairOfObjectInt;

/**
 * A conditional frequency distribution over pairs {@code (k, cond)} with {@code int} counts.
 *
 * <p>Counts are stored as one {@link Object2IntFrequencyDistribution} per conditioning key
 * {@code cond}. Two aggregates are maintained incrementally by {@link #set}: the marginal count of
 * each event {@code k} summed over all conditions, and the sum of all counts. {@link #check()}
 * recomputes both aggregates from the per-condition distributions and verifies them.
 *
 * <p>This class performs no synchronization.
 *
 * @param <K> type of both the events and the conditioning keys
 */
public class Object2IntConditionalFrequencyDistributionEntry<K extends Comparable<K>>
implements Object2IntConditionalFrequencyDistribution<K> {
    /** Per-condition distributions, keyed by the conditioning key. */
    private final Map<K, Object2IntFrequencyDistribution<K>> distributions = Maps.newHashMap();
    /** Marginal count of each event {@code k}, summed over every condition. */
    private final HMapKL<K> marginals = new HMapKL<K>();
    /** Sum of every count stored in every per-condition distribution. */
    private long sumOfAllCounts = 0L;

    /**
     * Sets the count of {@code k} under condition {@code cond} to {@code v}.
     *
     * <p>The conditional distribution for {@code cond} is created on first use. The marginal count
     * of {@code k} and the overall sum are adjusted by the difference between {@code v} and the
     * previous count.
     *
     * @param k the event
     * @param cond the conditioning key
     * @param v the new count
     */
    @Override
    public void set(K k, K cond, int v) {
        Object2IntFrequencyDistribution<K> fd = this.distributions.get(cond);
        int previous = 0;
        if (fd == null) {
            fd = new Object2IntFrequencyDistributionEntry<K>();
            fd.set(k, v);
            this.distributions.put(cond, fd);
        } else {
            previous = fd.get(k);
            fd.set(k, v);
        }
        // Widen before subtracting: (v - previous) in int arithmetic overflows when the two
        // counts have opposite signs and large magnitudes.
        long delta = (long) v - (long) previous;
        this.marginals.increment(k, delta);
        this.sumOfAllCounts += delta;
    }

    /**
     * Adds one to the count of {@code k} under condition {@code cond}.
     *
     * @param k the event
     * @param cond the conditioning key
     */
    @Override
    public void increment(K k, K cond) {
        this.increment(k, cond, 1);
    }

    /**
     * Adds {@code v} to the count of {@code k} under condition {@code cond}.
     *
     * @param k the event
     * @param cond the conditioning key
     * @param v the amount to add
     */
    @Override
    public void increment(K k, K cond, int v) {
        // get() yields 0 for an unknown condition, so first insertion needs no special case.
        this.set(k, cond, this.get(k, cond) + v);
    }

    /**
     * Returns the count of {@code k} under condition {@code cond}.
     *
     * @param k the event
     * @param cond the conditioning key
     * @return the stored count, or 0 if {@code cond} has never been seen
     */
    @Override
    public int get(K k, K cond) {
        Object2IntFrequencyDistribution<K> fd = this.distributions.get(cond);
        return fd == null ? 0 : fd.get(k);
    }

    /**
     * Returns the marginal count of {@code k}, summed over all conditions, as tracked in the
     * marginals map. NOTE(review): the result for a never-seen {@code k} is whatever
     * {@code HMapKL.get} returns for a missing key (presumably 0).
     *
     * @param k the event
     * @return the marginal count of {@code k}
     */
    @Override
    public long getMarginalCount(K k) {
        return this.marginals.get(k);
    }

    /**
     * Returns the distribution of events under condition {@code cond}.
     *
     * <p>If {@code cond} is known, the internally stored distribution itself is returned. Mutating
     * it directly bypasses the marginal and total bookkeeping, so callers should treat it as
     * read-only. For an unknown {@code cond}, a new empty distribution is returned; it is not
     * stored in this object.
     *
     * @param cond the conditioning key
     * @return the conditional distribution for {@code cond}
     */
    @Override
    public Object2IntFrequencyDistribution<K> getConditionalDistribution(K cond) {
        Object2IntFrequencyDistribution<K> fd = this.distributions.get(cond);
        return fd != null ? fd : new Object2IntFrequencyDistributionEntry<K>();
    }

    /**
     * Returns the sum of all counts across all conditions.
     *
     * @return the sum of all counts
     */
    @Override
    public long getSumOfAllCounts() {
        return this.sumOfAllCounts;
    }

    /**
     * Verifies the internal bookkeeping against the stored per-condition distributions.
     *
     * <p>The following are checked:
     * <ul>
     *   <li>each conditional distribution's entries sum to its reported sum of counts;</li>
     *   <li>those sums add up to {@link #getSumOfAllCounts()};</li>
     *   <li>every marginal recomputed from the distributions matches the stored marginal;</li>
     *   <li>every stored marginal is backed by the distributions.</li>
     * </ul>
     *
     * @throws RuntimeException if any invariant is violated
     */
    @Override
    public void check() {
        // Marginals recomputed from scratch; long-valued so large sums cannot wrap around.
        Map<K, Long> recomputed = Maps.newHashMap();
        long totalSum = 0L;
        for (Object2IntFrequencyDistribution<K> fd : this.distributions.values()) {
            long conditionalSum = 0L;
            for (PairOfObjectInt<K> pair : fd) {
                int count = pair.getRightElement();
                conditionalSum += count;
                K key = pair.getLeftElement();
                Long soFar = recomputed.get(key);
                recomputed.put(key, (soFar == null ? 0L : soFar.longValue()) + count);
            }
            if (conditionalSum != fd.getSumOfCounts()) {
                throw new RuntimeException("Internal Error!");
            }
            totalSum += fd.getSumOfCounts();
        }
        if (totalSum != this.getSumOfAllCounts()) {
            throw new RuntimeException("Internal Error! Got " + totalSum + ", Expected " + this.getSumOfAllCounts());
        }
        // Forward direction: every recomputed marginal agrees with the stored one.
        for (Map.Entry<K, Long> entry : recomputed.entrySet()) {
            if (entry.getValue().longValue() != this.marginals.get(entry.getKey())) {
                throw new RuntimeException("Internal Error!");
            }
        }
        // Reverse direction: every stored marginal is backed by the distributions. This catches
        // stale marginal entries that the forward pass cannot see.
        for (K key : this.marginals.keySet()) {
            Long soFar = recomputed.get(key);
            long expected = soFar == null ? 0L : soFar.longValue();
            if (this.marginals.get(key) != expected) {
                throw new RuntimeException("Internal Error!");
            }
        }
    }
}

