001package org.nasdanika.ai; 002 003import java.util.Arrays; 004import java.util.function.Function; 005 006/** 007 * Creates a Function<double[][],double[]> predictor for each label element with features for later elements including 008 * labels for earlier. 009 * During prediction earlier predictors outputs are used as inputs for later predictors. 010 * This is essentially how autoregression works. 011 */ 012public abstract class AbstractRecursiveDoubleFitter extends AbstractDoubleFitter { 013 014 @Override 015 protected final Function<double[][], double[][]> fit(double[][] features, double[][] labels) { 016 @SuppressWarnings("unchecked") 017 Function<double[][], double[]>[] predictors = new Function[labels[0].length]; 018 for (int i = 0; i < predictors.length; ++i) { 019 double[] pLabels = new double[labels.length]; 020 double[][] pFeatures = new double[features.length][]; 021 for (int j = 0; j < labels.length; ++j) { 022 pLabels[j] = labels[j][i]; 023 if (i == 0) { 024 pFeatures[j] = features[j]; 025 } else { 026 pFeatures[j] = Arrays.copyOf(features[j], features[j].length + i); 027 System.arraycopy(labels[j], 0, pFeatures[j], features[j].length, i); 028 } 029 } 030 031 predictors[i] = fit(pFeatures, pLabels); 032 } 033 034 return input -> { 035 double[][] predictions = new double[predictors.length][]; // To be transposed to output 036 for (int i = 0; i < predictors.length; ++i) { 037 if (i == 0) { 038 predictions[i] = predictors[i].apply(input); 039 } else { 040 double[][] pInput = new double[input.length][]; 041 for (int j = 0; j < input.length; ++j) { 042 pInput[j] = Arrays.copyOf(input[j], input[j].length + i); 043 for (int k = 0; k < i; ++k) { 044 pInput[j][input[j].length + k] = predictions[k][j]; 045 } 046 } 047 predictions[i] = predictors[i].apply(pInput); 048 } 049 } 050 051 double[][] output = new double[input.length][predictors.length]; 052 for (int i = 0; i < output.length; ++i) { 053 for (int j = 0; j < predictors.length; ++j) { 054 output[i][j] = predictions[j][i]; 055 } 056 } 057 return output; 058 }; 059 } 060 061 protected abstract Function<double[][], double[]> fit(double[][] features, double[] labels); 062 063}