/*
 * Decompiled with CFR 0.152.
 */
package tl.lin.data.cfd;

import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.Iterator;
import tl.lin.data.cfd.Int2IntConditionalFrequencyDistribution;
import tl.lin.data.fd.Int2IntFrequencyDistributionFastutil;
import tl.lin.data.fd.Int2LongFrequencyDistributionFastutil;
import tl.lin.data.pair.PairOfInts;

/**
 * {@link Int2IntConditionalFrequencyDistribution} backed by fastutil hash maps.
 *
 * <p>Counts are stored as one {@link Int2IntFrequencyDistributionFastutil} per condition. Alongside them the
 * class maintains, per event key, the marginal count summed over all conditions, and the grand total of all
 * counts. Both are updated incrementally by {@link #set(int, int, int)}. No synchronization is performed.
 */
public class Int2IntConditionalFrequencyDistributionFastutil
implements Int2IntConditionalFrequencyDistribution {
    /** Per-condition distributions, keyed by condition. */
    private final Int2ObjectMap<Int2IntFrequencyDistributionFastutil> distributions =
            new Int2ObjectOpenHashMap<Int2IntFrequencyDistributionFastutil>();
    /** Per-key counts summed over all conditions. */
    private final Int2LongFrequencyDistributionFastutil marginals = new Int2LongFrequencyDistributionFastutil();
    /** Sum of every count in every conditional distribution. */
    private long sumOfAllFrequencies = 0L;

    /**
     * Sets the count of {@code k} under condition {@code cond} to {@code v}, replacing any previous value.
     * The marginal for {@code k} and the grand total are adjusted by the difference between the new and old
     * value.
     *
     * @param k event key
     * @param cond condition
     * @param v new count
     */
    public void set(int k, int cond, int v) {
        Int2IntFrequencyDistributionFastutil fd = this.distributions.get(cond);
        if (fd == null) {
            fd = new Int2IntFrequencyDistributionFastutil();
            this.distributions.put(cond, fd);
        }
        // A key absent from fd reads as 0, so a brand-new entry contributes exactly v.
        int rv = fd.get(k);
        fd.set(k, v);
        // Widen before subtracting so the delta cannot overflow int.
        long delta = (long) v - (long) rv;
        this.marginals.increment(k, delta);
        this.sumOfAllFrequencies += delta;
    }

    /**
     * Adds one to the count of {@code k} under condition {@code cond}.
     *
     * @param k event key
     * @param cond condition
     */
    public void increment(int k, int cond) {
        this.increment(k, cond, 1);
    }

    /**
     * Adds {@code v} to the count of {@code k} under condition {@code cond}. A missing entry counts as 0.
     * The new count is computed in {@code int} arithmetic, so it wraps on overflow.
     *
     * @param k event key
     * @param cond condition
     * @param v amount to add
     */
    public void increment(int k, int cond, int v) {
        this.set(k, cond, this.get(k, cond) + v);
    }

    /**
     * Returns the count of {@code k} under condition {@code cond}.
     *
     * @param k event key
     * @param cond condition
     * @return the count, or 0 if the condition has never been seen
     */
    public int get(int k, int cond) {
        Int2IntFrequencyDistributionFastutil fd = this.distributions.get(cond);
        return fd == null ? 0 : fd.get(k);
    }

    /**
     * Returns the count of {@code k} summed over all conditions.
     *
     * @param k event key
     * @return the marginal count
     */
    public long getMarginalCount(int k) {
        return this.marginals.get(k);
    }

    /**
     * Returns the distribution stored for {@code cond}.
     *
     * <p>The returned object is the internal instance. Modifying it directly bypasses the marginal and total
     * bookkeeping done by {@link #set(int, int, int)}. If the condition is unknown, a fresh empty distribution
     * is returned; it is not stored in this object.
     *
     * @param cond condition
     * @return the conditional distribution, never {@code null}
     */
    public Int2IntFrequencyDistributionFastutil getConditionalDistribution(int cond) {
        Int2IntFrequencyDistributionFastutil fd = this.distributions.get(cond);
        return fd != null ? fd : new Int2IntFrequencyDistributionFastutil();
    }

    /**
     * Returns the sum of all counts across all conditions.
     *
     * @return the grand total
     */
    public long getSumOfAllCounts() {
        return this.sumOfAllFrequencies;
    }

    /**
     * Verifies that the incrementally maintained marginals and grand total agree with the per-condition
     * distributions.
     *
     * @throws RuntimeException if any inconsistency is detected
     */
    public void check() {
        // Re-derive the per-key marginals from scratch.
        Int2IntFrequencyDistributionFastutil m = new Int2IntFrequencyDistributionFastutil();
        long totalSum = 0L;
        for (Int2IntFrequencyDistributionFastutil fd : this.distributions.values()) {
            long conditionalSum = 0L;
            for (PairOfInts pair : fd) {
                conditionalSum += pair.getRightElement();
                m.increment(pair.getLeftElement(), pair.getRightElement());
            }
            if (conditionalSum != fd.getSumOfCounts()) {
                throw new RuntimeException("Internal Error!");
            }
            totalSum += fd.getSumOfCounts();
        }
        if (totalSum != this.getSumOfAllCounts()) {
            throw new RuntimeException("Internal Error! Got " + totalSum + ", Expected " + this.getSumOfAllCounts());
        }
        // Every key seen in a conditional distribution must have a matching stored marginal.
        for (PairOfInts e : m) {
            if ((long) e.getRightElement() != this.marginals.get(e.getLeftElement())) {
                throw new RuntimeException("Internal Error!");
            }
        }
        // The loop above only visits keys present in m. Comparing totals also catches stored marginal mass on
        // keys that occur in no conditional distribution (assuming counts are non-negative).
        if (this.marginals.getSumOfCounts() != totalSum) {
            throw new RuntimeException("Internal Error! Marginal sum " + this.marginals.getSumOfCounts()
                    + ", Expected " + totalSum);
        }
    }

    /**
     * Replaces this object's state with data in the format produced by {@link #write(DataOutput)}.
     *
     * @param in source of the serialized form
     * @throws IOException if reading fails
     */
    public void readFields(DataInput in) throws IOException {
        // Discard conditions from any previous use of this instance; otherwise they would be merged
        // into the freshly read data.
        this.distributions.clear();
        this.sumOfAllFrequencies = in.readLong();
        // NOTE(review): assumes marginals.readFields replaces (not merges) its existing contents — confirm.
        this.marginals.readFields(in);
        int sz = in.readInt();
        for (int i = 0; i < sz; ++i) {
            int key = in.readInt();
            Int2IntFrequencyDistributionFastutil fd = new Int2IntFrequencyDistributionFastutil();
            fd.readFields(in);
            this.distributions.put(key, fd);
        }
    }

    /**
     * Serializes this object. The format is: the grand total as a {@code long}, then the marginals, then the
     * number of conditions as an {@code int}, then for each condition its {@code int} key followed by its
     * distribution.
     *
     * @param out destination
     * @throws IOException if writing fails
     */
    public void write(DataOutput out) throws IOException {
        out.writeLong(this.sumOfAllFrequencies);
        this.marginals.write(out);
        out.writeInt(this.distributions.size());
        for (Int2ObjectMap.Entry<Int2IntFrequencyDistributionFastutil> e : this.distributions.int2ObjectEntrySet()) {
            out.writeInt(e.getIntKey());
            e.getValue().write(out);
        }
    }
}

