package org.nasdanika.ai;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.function.BinaryOperator;
import java.util.function.Function;

import org.nasdanika.ai.FittedPredictor.ErrorComputer;
import org.nasdanika.ai.FittedPredictor.Fitter;

import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

/**
 * Collects all samples to double[][] features and double[][] labels,
 * then calls fit(double[][] features, double[][] labels) to fit {@link Function}&lt;double[][],double[][]&gt;
 * which makes predictions for multiple features at once.
 * This allows to use a single base class for multiple types of predictors.
 */
public abstract class AbstractDoubleFitter implements FittedPredictor.Fitter<double[], double[], Double> {
	
	/** Intermediate pairing of a sample with its already-resolved feature vector. */
	private record FeaturesMapped<S>(double[] features, S sample) {}

	/** A fully resolved training row: feature vector plus label vector. */
	private record Data(double[] features, double[] labels) {}
	
	/**
	 * Asynchronously collects all samples, resolves features and labels via the
	 * provided mappers, and fits a predictor on the accumulated arrays.
	 *
	 * @param <S> sample type
	 * @param samples source of training samples
	 * @param featureMapper maps a sample to a Mono of its feature vector
	 * @param labelMapper maps a sample to a Mono of its label vector
	 * @return a Mono emitting the fitted predictor
	 */
	@Override
	public <S> Mono<FittedPredictor<double[], double[], Double>> fitAsync(
			Flux<S> samples, 
			Function<S, Mono<double[]>> featureMapper,
			Function<S, Mono<double[]>> labelMapper) {
		
		return samples
			.flatMap(s -> featureMapper.apply(s).map(f -> new FeaturesMapped<>(f, s)))
			// BUG FIX: labels must be produced by labelMapper. The original applied
			// featureMapper a second time here, so the model was fitted with the
			// feature vectors as labels.
			.flatMap(fm -> labelMapper.apply(fm.sample()).map(l -> new Data(fm.features(), l)))
			.collectList()
			.map(dataList -> {
				double[][] features = new double[dataList.size()][];
				double[][] labels = new double[dataList.size()][];
				int idx = 0;
				for (Data data: dataList) {
					features[idx] = data.features();
					labels[idx++] = data.labels();
				}
				
				Function<double[][],double[][]> predictor = fit(features, labels);
				return createPredictor(predictor, features, labels);
			});
	}
	
	/**
	 * Synchronously collects all samples into feature/label arrays and fits a predictor.
	 *
	 * @param <S> sample type
	 * @param samples training samples
	 * @param featureMapper maps a sample to its feature vector
	 * @param labelMapper maps a sample to its label vector
	 * @return the fitted predictor with its training error
	 */
	@Override
	public <S> FittedPredictor<double[], double[], Double> fit(
			Collection<S> samples, 
			Function<S, double[]> featureMapper, 
			Function<S, double[]> labelMapper) {
		
		double[][] features = new double[samples.size()][];
		double[][] labels = new double[samples.size()][];
		int idx = 0;
		for (S sample: samples) {
			features[idx] = featureMapper.apply(sample);
			labels[idx++] = labelMapper.apply(sample);
		}
		
		Function<double[][],double[][]> predictor = fit(features, labels);
		return createPredictor(predictor, features, labels);
	}
	
	/**
	 * Fits a batch predictor on the collected training data.
	 *
	 * @param features one feature vector per row
	 * @param labels one label vector per row, aligned with features
	 * @return a function predicting label vectors for a batch of feature vectors
	 */
	protected abstract Function<double[][],double[][]> fit(double[][] features, double[][] labels);
	
	/**
	 * Wraps the batch predictor into a {@link FittedPredictor}, computing the
	 * mean squared training error once up front.
	 *
	 * @param predictor batch predictor returned by {@link #fit(double[][], double[][])}
	 * @param features training features, used for error computation
	 * @param labels training labels, used for error computation
	 * @return a fitted predictor exposing single, batch and async prediction
	 */
	protected FittedPredictor<double[], double[], Double> createPredictor(
			Function<double[][],double[][]> predictor, 
			double[][] features, 
			double[][] labels) {
		
		Double error = computeError(predictor, features, labels);
		return new FittedPredictor<double[], double[], Double>() {
	
			@Override
			public double[] predict(double[] feature) {
				// Delegate the single prediction to the batch predictor with a 1-row batch.
				return predictor.apply(new double[][] { feature })[0];
			}
			
			@Override
			public List<Sample<double[], double[]>> predict(Collection<double[]> input) {
				double[][] inputArray = input.toArray(double[][]::new);
				double[][] results = predictor.apply(inputArray);
				// Presize to avoid incremental growth.
				List<Sample<double[], double[]>> ret = new ArrayList<>(inputArray.length);
				for (int i = 0; i < inputArray.length; ++i) {
					ret.add(new Sample<>(inputArray[i], results[i]));
				}
				return ret;
			}
			
			@Override
			public Mono<double[]> predictAsync(double[] input) {
				return Mono.fromSupplier(() -> predict(input));
			}
	
			@Override
			public Double getError() {
				return error;
			}
			
		};
		
	}
	
	/**
	 * Computes the mean squared error of the predictor over the training data.
	 * Null predictions (batch or per-row) are skipped; returns null when the
	 * batch prediction is null or no elements could be compared.
	 *
	 * @param predictor batch predictor
	 * @param features feature rows
	 * @param labels label rows aligned with features
	 * @return mean squared error, or null if it cannot be computed
	 */
	protected static Double computeError(
			Function<double[][],double[][]> predictor, 
			double[][] features, 
			double[][] labels) {
		
		double[][] prediction = predictor.apply(features);
		if (prediction == null) {
			return null;
		}
		double total = 0.0;
		int count = 0;
		for (int i = 0; i < labels.length; ++i) {
			double[] pe = prediction[i];
			if (pe != null) {
				// BUG FIX: the inner loop bound was labels[j].length (wrong row index),
				// which throws ArrayIndexOutOfBoundsException or iterates the wrong
				// number of columns for ragged label arrays. Use the current row i.
				for (int j = 0; j < labels[i].length; ++j) {
					double delta = pe[j] - labels[i][j];
					total += delta * delta;
					++count;
				}
			}
		}
		
		if (count == 0) {
			return null;
		}
		
		return total / count;
	}
	
	/**
	 * Adapts a scalar-output predictor to batch form: applies it row by row.
	 *
	 * @param predictor maps a single feature vector to a scalar prediction
	 * @return a function mapping a batch of feature vectors to a vector of scalar predictions
	 */
	public static Function<double[][], double[]> wrap(Function<double[], Double> predictor) {
		return input -> {
			double[] output = new double[input.length];
			for (int i = 0; i < input.length; ++i) {
				output[i] = predictor.apply(input[i]);
			}
			return output;
		};		
	}
	
	/**
	 * Composes this fitter with another: the other fitter is fitted on the residuals
	 * (label minus this fitter's prediction) and predictions are stacked by adding
	 * both fitters' outputs. Delegates to
	 * {@code compose(other, add, subtract, errorComputer)}.
	 *
	 * @param other the fitter to stack on top of this one
	 * @return the composed fitter
	 */
	public Fitter<double[], double[], Double> compose(Fitter<double[], double[], Double> other) {
		// Element-wise vector addition; operates on a copy, inputs are not mutated.
		BinaryOperator<double[]> add = (a,b) -> {
			double[] aCopy = Arrays.copyOf(a, a.length);
			for (int i = 0; i < a.length; ++i) {
				aCopy[i] += b[i];
			}
			return aCopy;
		};
		
		// Element-wise vector subtraction; operates on a copy, inputs are not mutated.
		BinaryOperator<double[]> subtract = (a,b) -> {
			double[] aCopy = Arrays.copyOf(a, a.length);
			for (int i = 0; i < a.length; ++i) {
				aCopy[i] -= b[i];
			}
			return aCopy;
		};
		
		ErrorComputer<double[], double[], Double> errorComputer = new ErrorComputer<double[], double[], Double>() {
			
			@Override
			public <S> Double computeError(
					Predictor<double[], double[]> predictor, 
					Collection<S> samples,
					Function<S, double[]> featureMapper, 
					Function<S, double[]> labelMapper) {
				
				double[][] features = new double[samples.size()][];
				double[][] labels = new double[samples.size()][];
				
				int idx = 0;
				for (S sample: samples) {
					features[idx] = featureMapper.apply(sample);
					labels[idx++] = labelMapper.apply(sample);
				}
				
				// Adapt the single-row Predictor to the batch signature expected by computeError.
				Function<double[][], double[][]> errorPredictor = input -> {
					double[][] output = new double[input.length][];
					for (int i = 0; i < input.length; ++i) {
						output[i] = predictor.predict(input[i]);
					}
					return output;
				};
				return AbstractDoubleFitter.computeError(
						errorPredictor, 
						features, 
						labels);
			}
			
		};
		
		return compose(other, add, subtract, errorComputer);
	}
	
	
	// TODO - stacking/composition binary operators to add/subtract labels.
	// Fit this one, fit the next one on label and prediction difference.
	// Predict by adding this prediciton to the next prediction.
	// Adapters to a single double result with support of stacking/composition too?
	
}