/*
 * Decompiled with CFR 0.152.
 */
package ml.dmlc.xgboost4j.java;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.Random;
import junit.framework.TestCase;
import ml.dmlc.xgboost4j.LabeledPoint;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoostError;
import org.junit.Test;

/**
 * Tests for {@link DMatrix} construction from the supported input sources (a row iterator,
 * a file, CSR/CSC sparse arrays and dense arrays) and for round-tripping labels and weights
 * through the native matrix.
 */
public class DMatrixTest {
    /** Fixed seed so the randomly generated dense fixtures are reproducible across runs. */
    private static final long SEED = 2016L;

    /**
     * Sets {@code labels} on {@code dmat}, reads them back and asserts they are identical
     * (same length and same values).
     *
     * @param dmat   matrix under test
     * @param labels labels to store; must have one entry per row of {@code dmat}
     * @throws XGBoostError if the native call fails
     */
    private static void assertLabelRoundTrip(DMatrix dmat, float[] labels) throws XGBoostError {
        dmat.setLabel(labels);
        TestCase.assertTrue("labels did not round-trip", Arrays.equals(labels, dmat.getLabel()));
    }

    @Test
    public void testCreateFromDataIterator() throws XGBoostError {
        int nrep = 3000;
        ArrayList<LabeledPoint> points = new ArrayList<LabeledPoint>(nrep);
        float[] expectedLabels = new float[nrep];
        for (int i = 0; i < nrep; ++i) {
            LabeledPoint p = new LabeledPoint(
                    0.1f + (float) i, new int[]{0, 2, 3}, new float[]{3.0f, 4.0f, 5.0f});
            points.add(p);
            expectedLabels[i] = p.label();
        }
        DMatrix dmat = new DMatrix(points.iterator(), null);
        // Compare whole arrays: a per-element loop bounded by labels.length would pass
        // vacuously if the matrix returned fewer labels than rows that were fed in.
        TestCase.assertTrue("labels read back from iterator-built matrix differ",
                Arrays.equals(expectedLabels, dmat.getLabel()));
    }

    @Test
    public void testCreateFromFile() throws XGBoostError {
        // NOTE(review): path is relative to the test working directory — confirm build layout.
        DMatrix dmat = new DMatrix("../../demo/data/agaricus.txt.test");
        float[] labels = dmat.getLabel();
        TestCase.assertEquals((long) labels.length, dmat.rowNum());
        float[] weights = Arrays.copyOf(labels, labels.length);
        dmat.setWeight(weights);
        TestCase.assertTrue("weights did not round-trip", Arrays.equals(weights, dmat.getWeight()));
    }

    @Test
    public void testCreateFromCSR() throws XGBoostError {
        // 3 rows; row r spans [rowHeaders[r], rowHeaders[r + 1]) of data/colIndex.
        float[] data = new float[]{1.0f, 2.0f, 3.0f, 4.0f, 2.0f, 3.0f, 5.0f, 3.0f, 1.0f, 2.0f, 5.0f};
        int[] colIndex = new int[]{0, 2, 3, 0, 2, 3, 4, 0, 1, 2, 3};
        long[] rowHeaders = new long[]{0L, 3L, 7L, 11L};
        DMatrix dmat = new DMatrix(rowHeaders, colIndex, data, DMatrix.SparseType.CSR);
        TestCase.assertEquals(3L, dmat.rowNum());
        assertLabelRoundTrip(dmat, new float[]{1.0f, 0.0f, 1.0f});
    }

    @Test
    public void testCreateFromCSREx() throws XGBoostError {
        float[] data = new float[]{1.0f, 2.0f, 3.0f, 4.0f, 2.0f, 3.0f, 5.0f, 3.0f, 1.0f, 2.0f, 5.0f};
        int[] colIndex = new int[]{0, 2, 3, 0, 2, 3, 4, 0, 1, 2, 3};
        long[] rowHeaders = new long[]{0L, 3L, 7L, 11L};
        // Same fixture as testCreateFromCSR, using the overload with an extra size argument
        // (5 = max column index + 1; presumably the column count — TODO confirm).
        DMatrix dmat = new DMatrix(rowHeaders, colIndex, data, DMatrix.SparseType.CSR, 5);
        TestCase.assertEquals(3L, dmat.rowNum());
        assertLabelRoundTrip(dmat, new float[]{1.0f, 0.0f, 1.0f});
    }

    @Test
    public void testCreateFromCSC() throws XGBoostError {
        // 3 columns; column c spans [colHeaders[c], colHeaders[c + 1]) of data/rowIndex.
        // The largest row index is 4, so the matrix has 5 rows.
        float[] data = new float[]{1.0f, 3.0f, 5.0f, 2.0f, 2.0f, 3.0f, 5.0f, 2.0f, 4.0f, 3.0f, 1.0f};
        int[] rowIndex = new int[]{0, 1, 3, 4, 2, 3, 4, 0, 1, 2, 3};
        long[] colHeaders = new long[]{0L, 4L, 7L, 11L};
        DMatrix dmat = new DMatrix(colHeaders, rowIndex, data, DMatrix.SparseType.CSC);
        TestCase.assertEquals(5L, dmat.rowNum());
        assertLabelRoundTrip(dmat, new float[]{1.0f, 0.0f, 1.0f, 1.0f, 1.0f});
    }

    @Test
    public void testCreateFromCSCEx() throws XGBoostError {
        float[] data = new float[]{1.0f, 3.0f, 5.0f, 2.0f, 2.0f, 3.0f, 5.0f, 2.0f, 4.0f, 3.0f, 1.0f};
        int[] rowIndex = new int[]{0, 1, 3, 4, 2, 3, 4, 0, 1, 2, 3};
        long[] colHeaders = new long[]{0L, 4L, 7L, 11L};
        // Same fixture as testCreateFromCSC, using the overload with an extra size argument
        // (5 = max row index + 1; presumably the row count — TODO confirm).
        DMatrix dmat = new DMatrix(colHeaders, rowIndex, data, DMatrix.SparseType.CSC, 5);
        TestCase.assertEquals(5L, dmat.rowNum());
        assertLabelRoundTrip(dmat, new float[]{1.0f, 0.0f, 1.0f, 1.0f, 1.0f});
    }

    @Test
    public void testCreateFromDenseMatrix() throws XGBoostError {
        int nrow = 10;
        int ncol = 5;
        Random random = new Random(SEED);
        float[] data = new float[nrow * ncol];
        for (int i = 0; i < data.length; ++i) {
            data[i] = random.nextFloat();
        }
        float[] labels = new float[nrow];
        for (int i = 0; i < nrow; ++i) {
            labels[i] = random.nextFloat();
        }
        DMatrix dmat = new DMatrix(data, nrow, ncol);
        TestCase.assertEquals((long) nrow, dmat.rowNum());
        assertLabelRoundTrip(dmat, labels);
        float[] weights = new float[nrow];
        for (int i = 0; i < nrow; ++i) {
            weights[i] = random.nextFloat();
        }
        dmat.setWeight(weights);
        TestCase.assertTrue("weights did not round-trip", Arrays.equals(weights, dmat.getWeight()));
    }

    @Test
    public void testCreateFromDenseMatrixWithMissingValue() throws XGBoostError {
        int nrow = 10;
        int ncol = 5;
        float missing = -0.1f;
        Random random = new Random(SEED);
        float[] data = new float[nrow * ncol];
        for (int i = 0; i < data.length; ++i) {
            // Every 10th cell holds the sentinel passed as the "missing" value below.
            data[i] = i % 10 == 0 ? missing : random.nextFloat();
        }
        float[] labels = new float[nrow];
        for (int i = 0; i < nrow; ++i) {
            labels[i] = random.nextFloat();
        }
        DMatrix dmat = new DMatrix(data, nrow, ncol, missing);
        dmat.setLabel(labels);
        TestCase.assertEquals((long) nrow, dmat.rowNum());
        TestCase.assertEquals(nrow, dmat.getLabel().length);
    }
}

