001package org.nasdanika.ai;
002
003import java.util.Arrays;
004import java.util.function.Function;
005
006/**
007 * Creates a Function<double[][],double[]> predictor for each label element with features for later elements including
008 * labels for earlier. 
009 * During prediction earlier predictors outputs are used as inputs for later predictors. 
010 * This is essentially how autoregression works. 
011 */
012public abstract class AbstractRecursiveDoubleFitter extends AbstractDoubleFitter {
013
014        @Override
015        protected final Function<double[][], double[][]> fit(double[][] features, double[][] labels) {
016                @SuppressWarnings("unchecked")
017                Function<double[][], double[]>[] predictors = new Function[labels[0].length];
018                for (int i = 0; i < predictors.length; ++i) {
019                        double[] pLabels = new double[labels.length];
020                        double[][] pFeatures = new double[features.length][];
021                        for (int j = 0; j < labels.length; ++j) {
022                                pLabels[j] = labels[j][i];
023                                if (i == 0) {
024                                        pFeatures[j] = features[j];
025                                } else {
026                                        pFeatures[j] = Arrays.copyOf(features[j], features[j].length + i);
027                                        System.arraycopy(labels[j], 0, pFeatures[j], features[j].length, i);
028                                }
029                        }
030                        
031                        predictors[i] = fit(pFeatures, pLabels);
032                }
033                
034                return input -> {
035                        double[][] predictions = new double[predictors.length][]; // To be transposed to output
036                        for (int i = 0; i < predictors.length; ++i) {
037                                if (i == 0) {
038                                        predictions[i] = predictors[i].apply(input);
039                                } else {
040                                        double[][] pInput = new double[input.length][];
041                                        for (int j = 0; j < input.length; ++j) {
042                                                pInput[j] = Arrays.copyOf(input[j], input[j].length + i);
043                                                for (int k = 0; k < i; ++k) {
044                                                        pInput[j][input[j].length + k] = predictions[k][j];
045                                                }       
046                                        }
047                                        predictions[i] = predictors[i].apply(pInput);
048                                }
049                        }
050                                                
051                        double[][] output = new double[input.length][predictors.length];
052                         for (int i = 0; i < output.length; ++i) {
053                                for (int j = 0; j < predictors.length; ++j) {
054                                    output[i][j] = predictions[j][i];
055                                }
056                            }                   
057                        return output;
058                };
059        }
060        
061        protected abstract Function<double[][], double[]> fit(double[][] features, double[] labels);
062
063}