package com.github.chen0040.glm.evaluators;

import com.github.chen0040.glm.utils.TupleTwo;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import java.util.stream.Collectors;

/* loaded from: input_file:com/github/chen0040/glm/evaluators/ConfusionMatrix.class */
public class ConfusionMatrix implements Serializable {
    private static final long serialVersionUID = 8446651320939507735L;

    /**
     * Cell counts keyed by (row label, column label). Keys that are absent count as zero.
     * NOTE(review): rows/columns are presumably (actual, predicted) — confirm against callers.
     */
    private Map<TupleTwo<String, String>, Integer> matrix = new HashMap<>();

    /** Every label recorded through {@link #incCount} or supplied through {@link #setLabels}. */
    private Set<String> labels = new HashSet<>();

    /** Guards {@code matrix} and {@code labels}; transient, so it is rebuilt in readObject. */
    private transient ReadWriteLock readWriteLock = new ReentrantReadWriteLock();

    /**
     * Increments the count of the cell at row {@code str}, column {@code str2} and registers
     * both labels.
     *
     * @param str  row label
     * @param str2 column label
     */
    public void incCount(String str, String str2) {
        readWriteLock.writeLock().lock();
        try {
            labels.add(str);
            labels.add(str2);
            matrix.merge(new TupleTwo<>(str, str2), 1, Integer::sum);
        } finally {
            readWriteLock.writeLock().unlock();
        }
    }

    /**
     * Returns a snapshot of the known labels.
     *
     * @return a new mutable list; its order follows the backing {@code HashSet} and is not
     *     guaranteed to be stable
     */
    public List<String> getLabels() {
        readWriteLock.readLock().lock();
        try {
            return new ArrayList<>(labels);
        } finally {
            readWriteLock.readLock().unlock();
        }
    }

    /**
     * Replaces the set of known labels. The recorded counts are left untouched.
     *
     * @param list the new labels
     * @throws NullPointerException if {@code list} is null; the current labels are kept
     */
    public void setLabels(List<String> list) {
        // Validate before clearing so a bad argument cannot leave the labels half-replaced.
        Objects.requireNonNull(list, "list");
        readWriteLock.writeLock().lock();
        try {
            labels.clear();
            labels.addAll(list);
        } finally {
            readWriteLock.writeLock().unlock();
        }
    }

    /**
     * Sums the row {@code str} across every known column label.
     *
     * @param str row label
     * @return the sum of {@code getCount(str, c)} over all known labels {@code c}
     */
    public int getRowSum(String str) {
        // One read lock for the whole loop, so the sum reflects a single consistent state.
        readWriteLock.readLock().lock();
        try {
            int sum = 0;
            for (String column : labels) {
                sum += countOf(str, column);
            }
            return sum;
        } finally {
            readWriteLock.readLock().unlock();
        }
    }

    /**
     * Sums the column {@code str} across every known row label.
     *
     * @param str column label
     * @return the sum of {@code getCount(r, str)} over all known labels {@code r}
     */
    public int getColumnSum(String str) {
        readWriteLock.readLock().lock();
        try {
            int sum = 0;
            for (String row : labels) {
                sum += countOf(row, str);
            }
            return sum;
        } finally {
            readWriteLock.readLock().unlock();
        }
    }

    /**
     * Returns the count of the cell at row {@code str}, column {@code str2}.
     *
     * @param str  row label
     * @param str2 column label
     * @return the recorded count, or 0 if the pair was never recorded
     */
    public int getCount(String str, String str2) {
        readWriteLock.readLock().lock();
        try {
            return countOf(str, str2);
        } finally {
            readWriteLock.readLock().unlock();
        }
    }

    /** Clears all counts. The known labels are retained. */
    public void reset() {
        readWriteLock.writeLock().lock();
        try {
            matrix.clear();
        } finally {
            readWriteLock.writeLock().unlock();
        }
    }

    /**
     * Returns the backing count map itself (not a copy).
     *
     * <p>Mutations made through the returned map bypass this object's lock. Callers that need
     * thread safety should use {@link #incCount} and {@link #getCount} instead.
     *
     * @return the live backing map
     */
    public Map<TupleTwo<String, String>, Integer> getMatrix() {
        // The lock gives a visibility guarantee for a reference installed by setMatrix.
        readWriteLock.readLock().lock();
        try {
            return matrix;
        } finally {
            readWriteLock.readLock().unlock();
        }
    }

    /**
     * Replaces the backing count map. The map is stored by reference, not copied.
     *
     * @param map the new backing map
     * @throws NullPointerException if {@code map} is null
     */
    public void setMatrix(Map<TupleTwo<String, String>, Integer> map) {
        Objects.requireNonNull(map, "map");
        readWriteLock.writeLock().lock();
        try {
            this.matrix = map;
        } finally {
            readWriteLock.writeLock().unlock();
        }
    }

    /** Looks up one cell without locking; the caller must already hold the read or write lock. */
    private int countOf(String row, String column) {
        return matrix.getOrDefault(new TupleTwo<>(row, column), 0);
    }

    /**
     * Restores the serialized fields and recreates the transient lock. Fields missing from the
     * stream are replaced with empty collections.
     */
    private void readObject(ObjectInputStream objectInputStream) throws IOException, ClassNotFoundException {
        objectInputStream.defaultReadObject();
        if (matrix == null) {
            matrix = new HashMap<>();
        }
        if (labels == null) {
            labels = new HashSet<>();
        }
        readWriteLock = new ReentrantReadWriteLock();
    }
}
