package org.nasdanika.ai.math;

import java.util.function.Function;

import org.apache.commons.math3.analysis.MultivariateFunction;
import org.nasdanika.ai.AbstractMapReduceDoubleFitter;

/**
 * Base class for fitters whose fitted model is a single Commons Math
 * {@link MultivariateFunction}. Subclasses supply that function through
 * {@link #fitFunction(double[][], double[])}. This class adapts it to the batch
 * {@code Function<double[][], double[]>} form that {@link #fit(double[][], double[])}
 * must return.
 *
 * <p>NOTE(review): "Multiivariate" in the class name is misspelled. It is kept unchanged
 * because renaming a public class would break existing subclasses and references.
 */
public abstract class AbstractMultiivariateFunctionPredictorFitter extends AbstractMapReduceDoubleFitter {

    /**
     * Fits a multivariate function to the given training data.
     *
     * @param features training inputs; presumably one row per sample (TODO confirm against callers)
     * @param labels training targets
     * @return the fitted function, or {@code null}; a {@code null} result causes
     *         {@link #fit(double[][], double[])} to return {@code null} as well
     */
    protected abstract MultivariateFunction fitFunction(double[][] features, double[] labels);

    /**
     * Fits a function via {@link #fitFunction(double[][], double[])} and wraps it as a batch predictor.
     *
     * <p>The returned predictor maps an array of input rows to an array of the same length.
     * Element {@code i} of the result is the fitted function's value at input row {@code i}.
     *
     * @param features training inputs, passed through to {@code fitFunction}
     * @param labels training targets, passed through to {@code fitFunction}
     * @return the batch predictor, or {@code null} if {@code fitFunction} returned {@code null}
     */
    @Override
    protected final Function<double[][], double[]> fit(double[][] features, double[] labels) {
        final MultivariateFunction fitted = fitFunction(features, labels);
        return fitted == null ? null : rows -> evaluateRows(fitted, rows);
    }

    /**
     * Evaluates {@code fn} at every row of {@code rows}, preserving row order.
     *
     * @param fn function to evaluate
     * @param rows input points, one per row
     * @return array whose element {@code i} is {@code fn.value(rows[i])}
     */
    private static double[] evaluateRows(MultivariateFunction fn, double[][] rows) {
        double[] results = new double[rows.length];
        int index = 0;
        for (double[] row : rows) {
            results[index++] = fn.value(row);
        }
        return results;
    }
}