001package org.nasdanika.ai.math; 002 003import java.util.List; 004 005import org.apache.commons.math3.analysis.UnivariateFunction; 006import org.apache.commons.math3.analysis.polynomials.PolynomialFunction; 007import org.apache.commons.math3.fitting.PolynomialCurveFitter; 008import org.apache.commons.math3.fitting.WeightedObservedPoint; 009 010/** 011 * Features and labels of size 1 - only the first element is taken 012 */ 013public class PolynomialPredictorFitter extends AbstractUnivariateFunctionPredictorFitter { 014 015 private int degree; 016 017 public PolynomialPredictorFitter(int degree) { 018 this.degree = degree; 019 } 020 021 @Override 022 protected UnivariateFunction fit(List<WeightedObservedPoint> points) { 023 PolynomialCurveFitter fitter = PolynomialCurveFitter.create(degree); 024 double[] params = fitter.fit(points); 025 return new PolynomialFunction(params); 026 } 027 028}