001package org.nasdanika.ai.math;
002
003import java.util.function.Function;
004
005import org.apache.commons.math3.analysis.MultivariateFunction;
006import org.nasdanika.ai.AbstractMapReduceDoubleFitter;
007
008public abstract class AbstractMultiivariateFunctionPredictorFitter extends AbstractMapReduceDoubleFitter {
009        
010        protected abstract MultivariateFunction fitFunction(double[][] features, double[] labels);
011
012        @Override
013        protected final Function<double[][], double[]> fit(double[][] features, double[] labels) {
014                MultivariateFunction func = fitFunction(features, labels);
015                
016                if (func == null) {
017                        return null;
018                }
019                
020                return input -> {
021                        double[] output = new double[input.length];
022                        for (int i = 0; i < input.length; ++i) {
023                                output[i] = func.value(input[i]);
024                        }
025                        return output;
026                };
027        }
028        
029}