/*
 * Decompiled with CFR 0.152.
 */
package tl.lin.data.cfd;

import java.util.HashMap;
import java.util.Map;
import tl.lin.data.cfd.Int2IntConditionalFrequencyDistribution;
import tl.lin.data.fd.Int2IntFrequencyDistribution;
import tl.lin.data.fd.Int2IntFrequencyDistributionEntry;
import tl.lin.data.fd.Int2LongFrequencyDistributionEntry;
import tl.lin.data.map.HMapIV;
import tl.lin.data.pair.PairOfInts;

/**
 * Conditional frequency distribution over int events and int conditions.
 *
 * <p>One {@link Int2IntFrequencyDistribution} is kept per condition. A marginal count per
 * event (summed over all conditions) and the grand total of all counts are maintained
 * incrementally on every {@link #set}.
 */
public class Int2IntConditionalFrequencyDistributionEntry
implements Int2IntConditionalFrequencyDistribution {
    /** Per-condition distributions, keyed by condition. */
    private final HMapIV<Int2IntFrequencyDistribution> distributions =
        new HMapIV<Int2IntFrequencyDistribution>();
    /** Per-event counts summed over all conditions; long so aggregates cannot wrap. */
    private final Int2LongFrequencyDistributionEntry marginals = new Int2LongFrequencyDistributionEntry();
    /** Sum of every count across all conditions. */
    private long sumOfAllFrequencies = 0L;

    /**
     * Sets the count of event {@code k} under condition {@code cond} to {@code v}.
     * The marginal count of {@code k} and the grand total are adjusted by the
     * difference between {@code v} and the previous count.
     *
     * @param k the event
     * @param cond the condition
     * @param v the new count
     */
    public void set(int k, int cond, int v) {
        Int2IntFrequencyDistribution fd;
        int previous;
        if (distributions.containsKey(cond)) {
            fd = distributions.get(cond);
            previous = fd.get(k);
        } else {
            fd = new Int2IntFrequencyDistributionEntry();
            distributions.put(cond, fd);
            previous = 0;
        }
        fd.set(k, v);
        // Compute the delta in long arithmetic: (v - previous) can overflow an int.
        long delta = (long) v - previous;
        marginals.increment(k, delta);
        sumOfAllFrequencies += delta;
    }

    /**
     * Increments the count of event {@code k} under condition {@code cond} by one.
     *
     * @param k the event
     * @param cond the condition
     */
    public void increment(int k, int cond) {
        increment(k, cond, 1);
    }

    /**
     * Increments the count of event {@code k} under condition {@code cond} by {@code v}.
     *
     * @param k the event
     * @param cond the condition
     * @param v the amount to add (an absent count is treated as 0)
     */
    public void increment(int k, int cond, int v) {
        set(k, cond, get(k, cond) + v);
    }

    /**
     * Returns the count of event {@code k} under condition {@code cond}.
     *
     * @param k the event
     * @param cond the condition
     * @return the count, or 0 if the condition has never been seen
     */
    public int get(int k, int cond) {
        if (!distributions.containsKey(cond)) {
            return 0;
        }
        return distributions.get(cond).get(k);
    }

    /**
     * Returns the count of event {@code k} summed over all conditions.
     *
     * @param k the event
     * @return the marginal count
     */
    public long getMarginalCount(int k) {
        return marginals.get(k);
    }

    /**
     * Returns the distribution of events under condition {@code cond}.
     *
     * <p>If the condition has never been seen, a fresh, empty distribution is returned.
     * That distribution is not stored, so modifying it does not affect this object.
     *
     * @param cond the condition
     * @return the stored distribution for {@code cond}, or a new empty one
     */
    public Int2IntFrequencyDistribution getConditionalDistribution(int cond) {
        if (distributions.containsKey(cond)) {
            return distributions.get(cond);
        }
        return new Int2IntFrequencyDistributionEntry();
    }

    /**
     * Returns the sum of all counts across all conditions.
     *
     * @return the grand total
     */
    public long getSumOfAllCounts() {
        return sumOfAllFrequencies;
    }

    /**
     * Verifies internal consistency by recomputing sums from the per-condition distributions.
     *
     * <p>Checks that:
     * <ul>
     *   <li>each distribution's entries add up to its reported sum;</li>
     *   <li>those sums add up to {@link #getSumOfAllCounts()};</li>
     *   <li>every recomputed per-event marginal matches the tracked marginal.</li>
     * </ul>
     *
     * @throws RuntimeException if any check fails
     */
    public void check() {
        // Accumulate recomputed marginals as long to match the tracked marginals;
        // an int accumulator could wrap and report a spurious inconsistency.
        Map<Integer, Long> recomputedMarginals = new HashMap<Integer, Long>();
        long totalSum = 0L;
        for (Int2IntFrequencyDistribution fd : distributions.values()) {
            long conditionalSum = 0L;
            for (PairOfInts pair : fd) {
                int event = pair.getLeftElement();
                int count = pair.getRightElement();
                conditionalSum += count;
                Long previous = recomputedMarginals.get(event);
                recomputedMarginals.put(event, previous == null ? (long) count : previous + count);
            }
            if (conditionalSum != fd.getSumOfCounts()) {
                throw new RuntimeException("Internal Error!");
            }
            totalSum += fd.getSumOfCounts();
        }
        if (totalSum != getSumOfAllCounts()) {
            throw new RuntimeException("Internal Error! Got " + totalSum + ", Expected " + getSumOfAllCounts());
        }
        for (Map.Entry<Integer, Long> e : recomputedMarginals.entrySet()) {
            if (e.getValue().longValue() != marginals.get(e.getKey().intValue())) {
                throw new RuntimeException("Internal Error!");
            }
        }
    }
}

