/*
 * Decompiled with CFR 0.152.
 */
package org.kie.kogito.explainability.utils;

import java.util.Arrays;
import java.util.Collection;
import java.util.stream.IntStream;
import org.apache.commons.lang3.tuple.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A linear model ({@code bias + sum(weights[i] * x[i])}) trained by per-sample gradient descent.
 *
 * <p>When {@code classification} is {@code true}, predictions are thresholded at zero and mapped
 * to {@code 1.0} (linear combination {@code >= 0}) or {@code 0.0}; otherwise the raw linear
 * combination is used as the prediction.
 *
 * <p>Instances are mutable: {@link #fit} updates the weights and the bias in place. No
 * synchronization is performed.
 */
public class LinearModel {
    private static final Logger logger = LoggerFactory.getLogger(LinearModel.class);
    /** Training stops early once an epoch's loss is at or below this value. */
    private static final double GOOD_LOSS_THRESHOLD = 0.1;
    /** Maximum number of passes over the training set performed by {@link #fit}. */
    private static final int MAX_NO_EPOCHS = 15;
    /** Learning rate used during the first epoch. */
    private static final double INITIAL_LEARNING_RATE = 0.01;
    /** After 0-based epoch {@code e}, the learning rate is multiplied by {@code 1 / (1 + DECAY_RATE * e)}. */
    private static final double DECAY_RATE = 0.01;
    private final double[] weights;
    private final boolean classification;
    private double bias = 0.0;

    /**
     * Creates an untrained model with all weights and the bias set to zero.
     *
     * @param size number of input features (and therefore of weights)
     * @param classification whether predictions are thresholded to {@code 0.0}/{@code 1.0}
     */
    public LinearModel(int size, boolean classification) {
        this.weights = new double[size];
        this.classification = classification;
    }

    /**
     * Fits the model giving every training sample the same (unit) weight.
     *
     * @param trainingSet pairs of (feature vector, target output)
     * @return the loss of the last epoch run, or {@code NaN} if {@code trainingSet} is empty
     * @see #fit(Collection, double[])
     */
    public double fit(Collection<Pair<double[], Double>> trainingSet) {
        double[] sampleWeights = new double[trainingSet.size()];
        Arrays.fill(sampleWeights, 1.0);
        return fit(trainingSet, sampleWeights);
    }

    /**
     * Fits the model by per-sample gradient descent with a decaying learning rate.
     *
     * <p>Training runs for at most {@value #MAX_NO_EPOCHS} epochs and stops earlier as soon as an
     * epoch's loss is at or below {@value #GOOD_LOSS_THRESHOLD}. The loss of an epoch is the sum
     * of absolute prediction errors divided by the training set size, where non-finite errors
     * count as zero. The errors are measured while the model is being updated within that epoch.
     *
     * <p>Sample weights scale each sample's weight and bias update. They are only honoured when
     * {@code sampleWeights} holds exactly one entry per training sample. Otherwise, including when
     * it is {@code null}, every sample is treated as having weight {@code 1.0}.
     *
     * <p>Each feature vector is expected to have exactly as many entries as the model has weights.
     *
     * @param trainingSet pairs of (feature vector, target output), iterated in the collection's order
     * @param sampleWeights per-sample weights aligned with the iteration order of {@code trainingSet}
     * @return the loss of the last epoch run, or {@code NaN} if {@code trainingSet} is empty
     */
    public double fit(Collection<Pair<double[], Double>> trainingSet, double[] sampleWeights) {
        if (trainingSet.isEmpty()) {
            logger.warn("fitting an empty training set");
            return Double.NaN;
        }
        boolean useSampleWeights = sampleWeights != null && sampleWeights.length == trainingSet.size();
        double lr = INITIAL_LEARNING_RATE;
        double finalLoss = Double.NaN;
        for (int epoch = 0; checkFinalLoss(finalLoss) && epoch < MAX_NO_EPOCHS; epoch++) {
            finalLoss = runEpoch(trainingSet, sampleWeights, useSampleWeights, lr);
            // learning rate decay, applied cumulatively across epochs
            lr *= 1.0 / (1.0 + DECAY_RATE * epoch);
            logger.debug("epoch {}, loss: {}", epoch + 1, finalLoss);
        }
        return finalLoss;
    }

    /**
     * Performs one pass over the training set, updating weights and bias after each sample.
     *
     * @return the epoch's loss: sum of absolute (finite) errors divided by the training set size
     */
    private double runEpoch(Collection<Pair<double[], Double>> trainingSet, double[] sampleWeights,
                            boolean useSampleWeights, double lr) {
        double loss = 0.0;
        int i = 0;
        for (Pair<double[], Double> sample : trainingSet) {
            double[] features = sample.getLeft();
            double diff = finiteOrZero(sample.getRight() - predict(features));
            if (diff != 0.0) { // a zero error produces no update, so skip the work
                loss += Math.abs(diff) / trainingSet.size();
                double sampleWeight = useSampleWeights ? sampleWeights[i] : 1.0;
                for (int j = 0; j < weights.length; j++) {
                    double v = lr * diff * features[j];
                    if (useSampleWeights) {
                        v *= sampleWeight;
                    }
                    weights[j] += finiteOrZero(v);
                }
                // the bias gets exactly one update per sample, independent of the number of weights
                bias += finiteOrZero(lr * diff * sampleWeight);
            }
            i++;
        }
        return loss;
    }

    /** Returns {@code true} while training should continue: no loss yet, or loss above threshold. */
    private boolean checkFinalLoss(double finalLoss) {
        return Double.isNaN(finalLoss) || finalLoss > GOOD_LOSS_THRESHOLD;
    }

    /** Returns {@code value} if it is finite, otherwise {@code 0.0} (guards against NaN/infinity). */
    private double finiteOrZero(double value) {
        return Double.isFinite(value) ? value : 0.0;
    }

    /**
     * Computes the model output for {@code input}.
     *
     * @return {@code bias + sum(input[i] * weights[i])}, or its zero-thresholded 0/1 value for
     *         classification models
     */
    private double predict(double[] input) {
        double linearCombination = bias + IntStream.range(0, input.length)
                .mapToDouble(i -> input[i] * weights[i])
                .sum();
        if (classification) {
            linearCombination = linearCombination >= 0.0 ? 1.0 : 0.0;
        }
        return linearCombination;
    }

    /**
     * Returns a copy of the current weights. Mutating the returned array does not affect the model.
     *
     * @return a snapshot of the model's weights, one per input feature
     */
    public double[] getWeights() {
        return Arrays.copyOf(weights, weights.length);
    }
}

