001package org.nasdanika.ai; 002 003import java.util.function.Function; 004 005/** 006 * Creates a Function<double[][],double[]> predictor for each label element, 007 * predicts label elements using provided predictors and then combines 008 */ 009public abstract class AbstractMapReduceDoubleFitter extends AbstractDoubleFitter { 010 011 @Override 012 protected final Function<double[][], double[][]> fit(double[][] features, double[][] labels) { 013 @SuppressWarnings("unchecked") 014 Function<double[][], double[]>[] predictors = new Function[labels[0].length]; 015 for (int i = 0; i < predictors.length; ++i) { 016 double[] pLabels = new double[labels.length]; 017 for (int j = 0; j < labels.length; ++j) { 018 pLabels[j] = labels[j][i]; 019 } 020 predictors[i] = fit(features, pLabels); 021 } 022 023 return input -> { 024 double[][] predictions = new double[predictors.length][]; // To be transposed to output 025 for (int i = 0; i < predictors.length; ++i) { 026 predictions[i] = predictors[i].apply(input); 027 } 028 029 double[][] output = new double[input.length][predictors.length]; 030 for (int i = 0; i < output.length; ++i) { 031 for (int j = 0; j < predictors.length; ++j) { 032 output[i][j] = predictions[j][i]; 033 } 034 } 035 return output; 036 }; 037 } 038 039 protected abstract Function<double[][], double[]> fit(double[][] features, double[] labels); 040 041}