001package org.nasdanika.ai;
002
003import java.util.function.Function;
004
005/**
006 * Creates a Function<double[][],double[]> predictor for each label element,
007 * predicts label elements using provided predictors and then combines 
008 */
009public abstract class AbstractMapReduceDoubleFitter extends AbstractDoubleFitter {
010
011        @Override
012        protected final Function<double[][], double[][]> fit(double[][] features, double[][] labels) {
013                @SuppressWarnings("unchecked")
014                Function<double[][], double[]>[] predictors = new Function[labels[0].length];
015                for (int i = 0; i < predictors.length; ++i) {
016                        double[] pLabels = new double[labels.length];
017                        for (int j = 0; j < labels.length; ++j) {
018                                pLabels[j] = labels[j][i];
019                        }
020                        predictors[i] = fit(features, pLabels);
021                }
022                
023                return input -> {
024                        double[][] predictions = new double[predictors.length][]; // To be transposed to output
025                        for (int i = 0; i < predictors.length; ++i) {
026                                predictions[i] = predictors[i].apply(input);
027                        }
028                                                
029                        double[][] output = new double[input.length][predictors.length];
030                         for (int i = 0; i < output.length; ++i) {
031                                for (int j = 0; j < predictors.length; ++j) {
032                                    output[i][j] = predictions[j][i];
033                                }
034                            }                   
035                        return output;
036                };
037        }
038        
039        protected abstract Function<double[][], double[]> fit(double[][] features, double[] labels);
040
041}