001package org.nasdanika.ai.math;
002
003import java.util.List;
004
005import org.apache.commons.math3.analysis.UnivariateFunction;
006import org.apache.commons.math3.analysis.polynomials.PolynomialFunction;
007import org.apache.commons.math3.fitting.PolynomialCurveFitter;
008import org.apache.commons.math3.fitting.WeightedObservedPoint;
009
010/**
011 * Features and labels of size 1 - only the first element is taken
012 */
013public class PolynomialPredictorFitter extends AbstractUnivariateFunctionPredictorFitter {
014        
015        private int degree;
016
017        public PolynomialPredictorFitter(int degree) {
018                this.degree = degree;
019        }
020
021        @Override
022        protected UnivariateFunction fit(List<WeightedObservedPoint> points) {
023                PolynomialCurveFitter fitter = PolynomialCurveFitter.create(degree);
024                double[] params = fitter.fit(points);
025                return new PolynomialFunction(params);
026        }
027        
028}