/*
 * Decompiled with CFR 0.152.
 */
package org.nasdanika.ai;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.function.BinaryOperator;
import java.util.function.Function;
import org.nasdanika.ai.FittedPredictor;
import org.nasdanika.ai.Predictor;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

public abstract class AbstractDoubleFitter
implements FittedPredictor.Fitter<double[], double[], Double> {
    /**
     * Asynchronously fits a predictor to a stream of samples.
     *
     * <p>Every sample is first resolved to its feature row and then to its label row. The
     * resulting pairs are collected, copied into two index-aligned arrays and passed to
     * {@link #fit(double[][], double[][])}; the fitted function is then wrapped by
     * {@code createPredictor}.
     *
     * @param <S> sample type
     * @param samples samples to fit on
     * @param featureMapper resolves the feature row of a sample
     * @param labelMapper resolves the label row of a sample
     * @return a {@link Mono} that emits the fitted predictor after all samples have been mapped
     */
    @Override
    public <S> Mono<FittedPredictor<double[], double[], Double>> fitAsync(Flux<S> samples, Function<S, Mono<double[]>> featureMapper, Function<S, Mono<double[]>> labelMapper) {
        return samples
            // Keep the sample next to its features so the label step can still see it.
            .flatMap(sample -> featureMapper.apply(sample).map(featureRow -> new FeaturesMapped<S>(featureRow, sample)))
            // Labels are resolved with labelMapper; features and labels travel together in one
            // Data record so rows stay paired no matter in which order they are emitted.
            .flatMap(mapped -> labelMapper.apply(mapped.sample()).map(labelRow -> new Data(mapped.features(), labelRow)))
            .collectList()
            .map(dataList -> {
                double[][] features = new double[dataList.size()][];
                double[][] labels = new double[dataList.size()][];
                int idx = 0;
                for (Data data : dataList) {
                    features[idx] = data.features();
                    labels[idx++] = data.labels();
                }
                Function<double[][], double[][]> predictor = this.fit(features, labels);
                return this.createPredictor(predictor, features, labels);
            });
    }

    /**
     * Fits a predictor to an in-memory collection of samples.
     *
     * <p>For each sample, in iteration order, the feature row is resolved first and the label
     * row second; both land at the same index of their respective arrays.
     *
     * @param <S> sample type
     * @param samples samples to fit on
     * @param featureMapper resolves the feature row of a sample
     * @param labelMapper resolves the label row of a sample
     * @return predictor built from the fitted function and the collected rows
     */
    @Override
    public <S> FittedPredictor<double[], double[], Double> fit(Collection<S> samples, Function<S, double[]> featureMapper, Function<S, double[]> labelMapper) {
        double[][] featureRows = new double[samples.size()][];
        double[][] labelRows = new double[samples.size()][];
        int row = 0;
        for (S current : samples) {
            featureRows[row] = featureMapper.apply(current);
            labelRows[row] = labelMapper.apply(current);
            row++;
        }
        return this.createPredictor(this.fit(featureRows, labelRows), featureRows, labelRows);
    }

    /**
     * Fits a model to index-aligned rows and returns it as a batch function.
     *
     * <p>Both public fit entry points of this class call this method with the collected
     * feature rows as the first argument and the collected label rows as the second.
     *
     * @param var1 feature rows, one per sample
     * @param var2 label rows, one per sample, at the same index as the matching feature row
     * @return function applied to an array of feature rows; callers in this class read its
     *         result as one prediction row per input row, at the same index
     */
    protected abstract Function<double[][], double[][]> fit(double[][] var1, double[][] var2);

    /**
     * Wraps a fitted batch function as a {@link FittedPredictor}.
     *
     * <p>The error is computed once, before the predictor is created, by running the function
     * against the supplied features and comparing with the supplied labels; the predictor
     * then reports that stored value.
     *
     * @param predictor batch function mapping feature rows to prediction rows
     * @param features feature rows used to compute the error
     * @param labels label rows used to compute the error
     * @return predictor delegating to the batch function
     */
    protected FittedPredictor<double[], double[], Double> createPredictor(final Function<double[][], double[][]> predictor, double[][] features, double[][] labels) {
        final Double fitError = AbstractDoubleFitter.computeError(predictor, features, labels);
        return new FittedPredictor<double[], double[], Double>() {

            /** Predicts a single row by sending a one-row batch through the function. */
            @Override
            public double[] predict(double[] feature) {
                double[][] singleRowBatch = {feature};
                double[][] batchResult = predictor.apply(singleRowBatch);
                return batchResult[0];
            }

            /** Predicts all rows in one batch call and pairs every input with its result. */
            @Override
            public List<Predictor.Sample<double[], double[]>> predict(Collection<double[]> input) {
                double[][] rows = input.toArray(double[][]::new);
                double[][] predicted = predictor.apply(rows);
                List<Predictor.Sample<double[], double[]>> samples = new ArrayList<>();
                int position = 0;
                while (position < rows.length) {
                    samples.add(new Predictor.Sample<double[], double[]>(rows[position], predicted[position]));
                    position++;
                }
                return samples;
            }

            /** Defers the synchronous single-row prediction until subscription. */
            @Override
            public Mono<double[]> predictAsync(double[] input) {
                return Mono.fromSupplier(() -> predict(input));
            }

            /** Returns the error computed when this predictor was created. */
            @Override
            public Double getError() {
                return fitError;
            }
        };
    }

    /**
     * Computes the mean squared error of a batch function over a data set.
     *
     * <p>The function is applied once to all feature rows. Every element of every label row
     * is compared with the element at the same position of the matching prediction row;
     * rows for which the prediction is {@code null} are skipped.
     *
     * @param predictor batch function mapping feature rows to prediction rows
     * @param features feature rows passed to the function
     * @param labels expected rows, index-aligned with {@code features}
     * @return mean of the squared differences over all compared elements, or {@code null}
     *         when the function returns {@code null} or nothing could be compared
     */
    protected static Double computeError(Function<double[][], double[][]> predictor, double[][] features, double[][] labels) {
        double[][] prediction = predictor.apply(features);
        if (prediction == null) {
            return null;
        }
        double total = 0.0;
        int count = 0;
        for (int i = 0; i < labels.length; ++i) {
            double[] predictedRow = prediction[i];
            if (predictedRow == null) {
                continue;
            }
            double[] labelRow = labels[i];
            // The column bound is the length of the current row (row i), not of row j.
            for (int j = 0; j < labelRow.length; ++j) {
                double delta = predictedRow[j] - labelRow[j];
                total += delta * delta;
                ++count;
            }
        }
        if (count == 0) {
            return null;
        }
        return total / (double) count;
    }

    /**
     * Adapts a single-row, single-value predictor into a batch function.
     *
     * @param predictor function producing one value for one feature row
     * @return function that applies {@code predictor} to every row, in order, and returns
     *         the values in an array of the same length as the input
     */
    public static Function<double[][], double[]> wrap(Function<double[], Double> predictor) {
        return rows -> {
            double[] values = new double[rows.length];
            int position = 0;
            for (double[] row : rows) {
                values[position++] = predictor.apply(row);
            }
            return values;
        };
    }

    /**
     * Composes this fitter with another one using element-wise arithmetic on label rows and
     * mean squared error as the error measure.
     *
     * <p>This method only supplies the three collaborators and delegates to the
     * four-argument {@code compose}, which is declared outside this file.
     *
     * @param other fitter to compose with
     * @return the fitter returned by the four-argument {@code compose}
     */
    public FittedPredictor.Fitter<double[], double[], Double> compose(FittedPredictor.Fitter<double[], double[], Double> other) {
        // Both operators work on a copy so that neither operand is modified.
        BinaryOperator<double[]> add = (a, b) -> {
            double[] sum = Arrays.copyOf(a, a.length);
            for (int i = 0; i < a.length; ++i) {
                sum[i] += b[i];
            }
            return sum;
        };
        BinaryOperator<double[]> subtract = (a, b) -> {
            double[] difference = Arrays.copyOf(a, a.length);
            for (int i = 0; i < a.length; ++i) {
                difference[i] -= b[i];
            }
            return difference;
        };
        FittedPredictor.ErrorComputer<double[], double[], Double> errorComputer = new FittedPredictor.ErrorComputer<double[], double[], Double>() {

            /**
             * Resolves feature and label rows for all samples, predicts row by row and
             * returns the mean squared error against the labels.
             */
            @Override
            public <S> Double computeError(Predictor<double[], double[]> predictor, Collection<S> samples, Function<S, double[]> featureMapper, Function<S, double[]> labelMapper) {
                double[][] features = new double[samples.size()][];
                double[][] labels = new double[samples.size()][];
                int idx = 0;
                for (S sample : samples) {
                    features[idx] = featureMapper.apply(sample);
                    labels[idx++] = labelMapper.apply(sample);
                }
                // Adapt the single-row predictor to the batch shape the static helper expects.
                Function<double[][], double[][]> errorPredictor = input -> {
                    double[][] output = new double[input.length][];
                    for (int i = 0; i < input.length; ++i) {
                        output[i] = predictor.predict(input[i]);
                    }
                    return output;
                };
                // Qualified on purpose: the simple name would resolve to this method.
                return AbstractDoubleFitter.computeError(errorPredictor, features, labels);
            }
        };
        return this.compose(other, add, subtract, errorComputer);
    }

    /**
     * Feature row and label row resolved for one sample; produced by the asynchronous fit
     * pipeline before the rows are copied into arrays.
     */
    private record Data(double[] features, double[] labels) {
    }

    /**
     * Intermediate step of the asynchronous fit pipeline: a sample together with its already
     * resolved feature row, kept so the sample is still available to the following mapping step.
     *
     * @param <S> sample type
     */
    private record FeaturesMapped<S>(double[] features, S sample) {
    }
}

