/*
 * Decompiled with CFR 0.152.
 */
package de.unidue.ltl.evaluation.visualization;

import de.unidue.ltl.evaluation.core.EvaluationData;
import de.unidue.ltl.evaluation.visualization.ConfusionMatrix;
import java.awt.Color;
import java.awt.Dimension;
import java.awt.Paint;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.text.DecimalFormat;
import java.util.List;
import org.jfree.chart.ChartPanel;
import org.jfree.chart.ChartUtilities;
import org.jfree.chart.JFreeChart;
import org.jfree.chart.annotations.XYAnnotation;
import org.jfree.chart.annotations.XYTextAnnotation;
import org.jfree.chart.axis.NumberAxis;
import org.jfree.chart.axis.NumberTickUnit;
import org.jfree.chart.axis.ValueAxis;
import org.jfree.chart.plot.Plot;
import org.jfree.chart.plot.XYPlot;
import org.jfree.chart.renderer.GrayPaintScale;
import org.jfree.chart.renderer.PaintScale;
import org.jfree.chart.renderer.xy.XYBlockRenderer;
import org.jfree.chart.renderer.xy.XYItemRenderer;
import org.jfree.data.xy.DefaultXYZDataset;
import org.jfree.data.xy.XYDataset;
import org.jfree.data.xy.XYZDataset;

public class ConfusionMatrixHeatmap {
    private ConfusionMatrix<String> cf;
    private List<String> labels;

    public ConfusionMatrixHeatmap(EvaluationData<String> data) {
        ConfusionMatrix<String> confusionMatrix = new ConfusionMatrix<String>(data);
        this.cf = confusionMatrix;
        this.labels = this.cf.getLabels();
    }

    public ConfusionMatrixHeatmap(ConfusionMatrix<String> confusionMatrix) {
        this.cf = confusionMatrix;
        this.labels = this.cf.getLabels();
    }

    public void writePlot(File targetFile) throws IOException {
        DecimalFormat formatter = new DecimalFormat();
        formatter.setMaximumFractionDigits(2);
        formatter.setMinimumFractionDigits(2);
        XYBlockRenderer renderer = new XYBlockRenderer();
        renderer.setBlockHeight(1.0);
        renderer.setBlockWidth(1.0);
        renderer.setPaintScale((PaintScale)new GrayPaintScale(0.0, 1.0));
        for (int i = 0; i < this.labels.size(); ++i) {
            renderer.setSeriesShape(i, null);
            renderer.setSeriesCreateEntities(i, Boolean.valueOf(false));
        }
        NumberAxis axisA = new NumberAxis("Gold");
        NumberAxis axisB = new NumberAxis("Predicted");
        axisA.setTickUnit(new NumberTickUnit(1.0));
        axisB.setTickUnit(new NumberTickUnit(1.0));
        axisA.setRange(0.5, (double)this.labels.size() + 0.5);
        axisB.setRange(0.5, (double)this.labels.size() + 0.5);
        axisB.setInverted(true);
        XYZDataset dataset = this.getDataset();
        XYPlot aPlot = new XYPlot((XYDataset)dataset, (ValueAxis)axisA, (ValueAxis)axisB, (XYItemRenderer)renderer);
        aPlot.setOutlinePaint((Paint)Color.black);
        for (int i = 0; i < this.labels.size(); ++i) {
            for (int j = 0; j < this.labels.size(); ++j) {
                Double textValue = 1.0 - dataset.getZValue(i, j);
                XYTextAnnotation cellText = new XYTextAnnotation(formatter.format(textValue), dataset.getXValue(i, j), dataset.getYValue(i, j));
                cellText.setPaint((Paint)Color.black);
                aPlot.addAnnotation((XYAnnotation)cellText);
            }
        }
        JFreeChart chart = new JFreeChart("Confusion Matrix Heatmap", JFreeChart.DEFAULT_TITLE_FONT, (Plot)aPlot, false);
        chart.setBackgroundPaint((Paint)Color.white);
        ChartPanel panel = new ChartPanel(chart, true, true, true, true, true);
        panel.setPreferredSize(new Dimension(900, 850));
        ChartUtilities.writeChartAsPNG((OutputStream)new FileOutputStream(targetFile), (JFreeChart)chart, (int)(50 * this.labels.size()), (int)(40 * this.labels.size()));
    }

    private XYZDataset getDataset() {
        DefaultXYZDataset data = new DefaultXYZDataset();
        for (int i = 0; i < this.labels.size(); ++i) {
            String label = this.labels.get(i);
            double[][] series = this.getSeries(label, i, this.cf.getNumberOfEntries());
            data.addSeries((Comparable)((Object)(i + ": " + label)), series);
        }
        return data;
    }

    private double[][] getSeries(String label, int i, long n) {
        double[][] series = new double[3][this.labels.size()];
        for (int j = 0; j < this.labels.size(); ++j) {
            String currentConfusionLabel = this.labels.get(j);
            double confusionRatio = 0.0;
            if (!label.equals(currentConfusionLabel)) {
                confusionRatio = (double)this.cf.getNumberOfConfusions(label, currentConfusionLabel) / (double)n;
            }
            series[0][j] = i + 1;
            series[1][j] = j + 1;
            series[2][j] = 1.0 - confusionRatio;
        }
        return series;
    }
}

