001package org.nasdanika.ai.math;
002
003import java.util.List;
004import java.util.function.Function;
005
006import org.apache.commons.math3.analysis.UnivariateFunction;
007import org.apache.commons.math3.fitting.WeightedObservedPoint;
008import org.apache.commons.math3.fitting.WeightedObservedPoints;
009import org.nasdanika.ai.AbstractMapReduceDoubleFitter;
010
011/**
012 * Features and labels of size 1 - only the first element is taken
013 */
014public abstract class AbstractUnivariateFunctionPredictorFitter extends AbstractMapReduceDoubleFitter {
015        
016        protected abstract UnivariateFunction fit(List<WeightedObservedPoint> points);
017
018        @Override
019        protected final Function<double[][], double[]> fit(double[][] features, double[] labels) {
020                WeightedObservedPoints wobs = new WeightedObservedPoints();
021                for (int i = 0; i < features.length; ++i) {
022                        double[] fi = features[i];
023                        if (fi.length != 1) {
024                                throw new IllegalArgumentException("Features array shall be of size 1");
025                        }
026                        wobs.add(fi[0], labels[i]);
027                }
028                
029                UnivariateFunction func = fit(wobs.toList());
030                
031                if (func == null) {
032                        return null;
033                }
034                
035                return wrap(input -> {
036                        if (input.length != 1) {
037                                throw new IllegalArgumentException("Input shall be of size 1");
038                        }
039                        return func.value(input[0]);
040                });
041        }
042        
043}