/*
 * Decompiled with CFR 0.152.
 */
package org.kie.kogito.explainability.utils;

import java.util.LinkedList;
import java.util.Random;
import java.util.stream.DoubleStream;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;
import org.kie.kogito.explainability.utils.LinearModel;

/**
 * Tests for {@code LinearModel}: fitting an empty training set must leave the
 * weights at zero, and fitting small synthetic regression / classification
 * sets must report a value below {@code 1.0}.
 */
class LinearModelTest {
    /** Number of features per sample, also passed as the model size. */
    private static final int FEATURE_COUNT = 10;

    /** Number of synthetic samples generated for the non-empty fit tests. */
    private static final int SAMPLE_COUNT = 100;

    /** Shared, unseeded source of randomness handed to every model under test. */
    private static final Random RANDOM = new Random();

    @Test
    void testEmptyFitClassificationDoesNothing() {
        assertEmptyFitKeepsZeroWeights(true);
    }

    @Test
    void testEmptyFitRegressionDoesNothing() {
        assertEmptyFitKeepsZeroWeights(false);
    }

    /** Fits a regression model where each target is the sum of the sample's features. */
    @Test
    void testRegressionFit() {
        LinearModel model = new LinearModel(FEATURE_COUNT, false, RANDOM);
        LinkedList<ImmutablePair> samples = new LinkedList<ImmutablePair>();
        for (int row = 0; row < SAMPLE_COUNT; row++) {
            double[] features = featuresFor(row);
            // NOTE(review): for row 0 the features contain NaN, so this target is NaN too.
            Double target = DoubleStream.of(features).sum();
            samples.add(new ImmutablePair(features, target));
        }
        double fitResult = model.fit(samples);
        Assertions.assertThat(fitResult).isLessThan(1.0);
    }

    /** Fits a classification model where even-indexed samples are labelled 1.0 and odd ones 0.0. */
    @Test
    void testClassificationFit() {
        LinearModel model = new LinearModel(FEATURE_COUNT, true, RANDOM);
        LinkedList<ImmutablePair> samples = new LinkedList<ImmutablePair>();
        for (int row = 0; row < SAMPLE_COUNT; row++) {
            double[] features = featuresFor(row);
            Double label;
            if (row % 2 == 0) {
                label = 1.0;
            } else {
                label = 0.0;
            }
            samples.add(new ImmutablePair(features, label));
        }
        double fitResult = model.fit(samples);
        Assertions.assertThat(fitResult).isLessThan(1.0);
    }

    /**
     * Creates a model of {@link #FEATURE_COUNT} weights, fits it on an empty
     * training set and asserts that every weight is still {@code 0.0}.
     *
     * @param classification flag forwarded as the second {@code LinearModel} constructor argument
     */
    private static void assertEmptyFitKeepsZeroWeights(boolean classification) {
        LinearModel model = new LinearModel(FEATURE_COUNT, classification, RANDOM);
        LinkedList noSamples = new LinkedList();
        model.fit(noSamples);
        // A freshly allocated double[] is all zeros, which is the expected weight vector.
        double[] expectedWeights = new double[FEATURE_COUNT];
        org.junit.jupiter.api.Assertions.assertArrayEquals(expectedWeights, (double[]) model.getWeights());
    }

    /**
     * Builds the synthetic feature vector for one sample: column {@code c} of
     * sample {@code row} holds {@code row / (c + row)}.
     *
     * <p>NOTE(review): for {@code row == 0} and column 0 this evaluates
     * {@code 0.0 / 0.0}, i.e. {@code NaN}; kept as-is to preserve the existing
     * test data.
     *
     * @param row zero-based sample index
     * @return a new array of {@link #FEATURE_COUNT} feature values
     */
    private static double[] featuresFor(int row) {
        double[] features = new double[FEATURE_COUNT];
        for (int column = 0; column < FEATURE_COUNT; column++) {
            double denominator = 1.0 * column + row;
            features[column] = row / denominator;
        }
        return features;
    }
}

