001package org.nasdanika.ai.math; 002 003import java.util.List; 004import java.util.function.Function; 005 006import org.apache.commons.math3.analysis.UnivariateFunction; 007import org.apache.commons.math3.fitting.WeightedObservedPoint; 008import org.apache.commons.math3.fitting.WeightedObservedPoints; 009import org.nasdanika.ai.AbstractMapReduceDoubleFitter; 010 011/** 012 * Features and labels of size 1 - only the first element is taken 013 */ 014public abstract class AbstractUnivariateFunctionPredictorFitter extends AbstractMapReduceDoubleFitter { 015 016 protected abstract UnivariateFunction fit(List<WeightedObservedPoint> points); 017 018 @Override 019 protected final Function<double[][], double[]> fit(double[][] features, double[] labels) { 020 WeightedObservedPoints wobs = new WeightedObservedPoints(); 021 for (int i = 0; i < features.length; ++i) { 022 double[] fi = features[i]; 023 if (fi.length != 1) { 024 throw new IllegalArgumentException("Features array shall be of size 1"); 025 } 026 wobs.add(fi[0], labels[i]); 027 } 028 029 UnivariateFunction func = fit(wobs.toList()); 030 031 if (func == null) { 032 return null; 033 } 034 035 return wrap(input -> { 036 if (input.length != 1) { 037 throw new IllegalArgumentException("Input shall be of size 1"); 038 } 039 return func.value(input[0]); 040 }); 041 } 042 043}