001package org.nasdanika.ai; 002 003import java.util.Collection; 004import java.util.function.BinaryOperator; 005import java.util.function.Function; 006 007import reactor.core.publisher.Flux; 008import reactor.core.publisher.Mono; 009 010/** 011 * A predictor which is fitted (trained) 012 * @param <F> 013 * @param <L> 014 * @param <E> 015 */ 016public interface FittedPredictor<F,L,E> extends Predictor<F,L> { 017 018 interface ErrorComputer<F,L,E> { 019 020 <S> E computeError(Predictor<F,L> predictor, 021 Collection<S> samples, 022 Function<S, F> featureMapper, 023 Function<S, L> labelMapper); 024 025 } 026 027 interface Fitter<F,L,E> { 028 029 /** 030 * Creates a predictor by fitting a collection of samples. 031 * @param <S> 032 * @param samples 033 * @param featureMapper 034 * @param labelMapper 035 * @return 036 */ 037 default <S> FittedPredictor<F,L,E> fit( 038 Collection<S> samples, 039 Function<S,F> featureMapper, 040 Function<S,L> labelMapper) { 041 042 return fitAsync( 043 Flux.fromIterable(samples), 044 s -> Mono.fromSupplier(() -> featureMapper.apply(s)), 045 s -> Mono.fromSupplier(() -> labelMapper.apply(s))).block(); 046 } 047 048 /** 049 * Creates a predictor by fitting a flux of samples. 050 * @param <S> 051 * @param samples 052 * @param featureMapper 053 * @param labelMapper 054 * @return A mono providing a predictor. 055 * The mono can publish a predictor before the flux is finished and update the predictor with new 056 * samples from the flux. 057 */ 058 <S> Mono<FittedPredictor<F,L,E>> fitAsync( 059 Flux<S> samples, 060 Function<S,Mono<F>> featureMapper, 061 Function<S,Mono<L>> labelMapper); 062 063 default <G> Fitter<G,L,E> adaptFeature(Function<G,F> featureMapper) { 064 065 return new Fitter<G,L,E>() { 066 067 @Override 068 public <S> Mono<FittedPredictor<G, L, E>> fitAsync( 069 Flux<S> samples, 070 Function<S, Mono<G>> theFeatureMapper, 071 Function<S, Mono<L>> labelMapper) { 072 073 Function<S, Mono<F>> featureMapperChain = theFeatureMapper.andThen(m -> m.map(featureMapper)); 074 Mono<FittedPredictor<F, L, E>> fResult = Fitter.this.fitAsync(samples, featureMapperChain, labelMapper); 075 return fResult.map(fp -> fp.adaptFeature(featureMapper)); 076 } 077 078 @Override 079 public <S> FittedPredictor<G, L, E> fit( 080 Collection<S> samples, 081 Function<S, G> theFeatureMapper, 082 Function<S, L> labelMapper) { 083 084 FittedPredictor<F, L, E> predictor = Fitter.this.fit(samples, theFeatureMapper.andThen(featureMapper), labelMapper); 085 return predictor.adaptFeature(featureMapper); 086 } 087 088 }; 089 090 } 091 092 default <G> Fitter<G,L,E> adaptFeatureAsync(Function<G,Mono<F>> featureMapper) { 093 094 return new Fitter<G,L,E>() { 095 096 @Override 097 public <S> Mono<FittedPredictor<G, L, E>> fitAsync( 098 Flux<S> samples, 099 Function<S, Mono<G>> theFeatureMapper, 100 Function<S, Mono<L>> labelMapper) { 101 102 Function<S, Mono<F>> featureMapperChain = theFeatureMapper.andThen(m -> m.flatMap(featureMapper)); 103 Mono<FittedPredictor<F, L, E>> fResult = Fitter.this.fitAsync(samples, featureMapperChain, labelMapper); 104 return fResult.map(fp -> fp.adaptFeatureAsync(featureMapper)); 105 } 106 107 }; 108 109 } 110 111 default <M> Fitter<F,M,E> adaptLabel(Function<M,L> fitMapper, Function<L,M> predictMapper) { 112 113 return new Fitter<F,M,E>() { 114 115 @Override 116 public <S> FittedPredictor<F, M, E> fit( 117 Collection<S> samples, 118 Function<S, F> featureMapper, 119 Function<S, M> labelMapper) { 120 121 Function<S, L> fitMapperChain = labelMapper.andThen(fitMapper); 122 FittedPredictor<F, L, E> result = Fitter.this.fit(samples, featureMapper, fitMapperChain); 123 return result.adaptLabel(predictMapper); 124 } 125 126 @Override 127 public <S> Mono<FittedPredictor<F, M, E>> fitAsync( 128 Flux<S> samples, 129 Function<S, Mono<F>> featureMapper, 130 Function<S, Mono<M>> labelMapper) { 131 132 Function<S, Mono<L>> fitMapperChain = labelMapper.andThen(m -> m.map(fitMapper)); 133 Mono<FittedPredictor<F, L, E>> result = Fitter.this.fitAsync(samples, featureMapper, fitMapperChain); 134 return result.map(fp -> fp.adaptLabel(predictMapper)); 135 } 136 137 }; 138 139 } 140 141 default <M> Fitter<F,M,E> adaptLabelAsync(Function<M,Mono<L>> fitMapper, Function<L,Mono<M>> predictMapper) { 142 143 return new Fitter<F,M,E>() { 144 145 @Override 146 public <S> Mono<FittedPredictor<F, M, E>> fitAsync( 147 Flux<S> samples, 148 Function<S, Mono<F>> featureMapper, 149 Function<S, Mono<M>> labelMapper) { 150 151 Function<S, Mono<L>> fitMapperChain = labelMapper.andThen(m -> m.flatMap(fitMapper)); 152 Mono<FittedPredictor<F, L, E>> result = Fitter.this.fitAsync(samples, featureMapper, fitMapperChain); 153 return result.map(fp -> fp.adaptLabelAsync(predictMapper)); 154 } 155 156 }; 157 158 } 159 160 /** 161 * Composes two predictors by fitting this one, then computing label difference between this predictor predictions and 162 * labels and fitting the next one with the difference. 163 * Prediction is done by computing prediction for this one, then the other and adding the two together. 164 * @param other 165 * @param add 166 * @param subtract 167 * @return 168 */ 169 default Fitter<F,L,E> compose( 170 Fitter<F,L,E> other, 171 BinaryOperator<L> add, 172 BinaryOperator<L> subtract, 173 ErrorComputer<F, L, E> errorComputer) { 174 175 return new Fitter<F, L, E>() { 176 177 @Override 178 public <S> FittedPredictor<F, L, E> fit( 179 Collection<S> samples, 180 Function<S, F> featureMapper, 181 Function<S, L> labelMapper) { 182 183 FittedPredictor<F,L,E> thisPredictor = Fitter.this.fit(samples, featureMapper, labelMapper); 184 FittedPredictor<F, L, E> otherPredictor = other.fit( 185 samples, 186 featureMapper, 187 s -> { 188 L label = labelMapper.apply(s); 189 L prediction = thisPredictor.predict(featureMapper.apply(s)); 190 return subtract.apply(label, prediction); 191 }); 192 193 return new FittedPredictor<F,L,E>() { 194 195 @Override 196 public L predict(F feature) { 197 L thisPrediction = thisPredictor.predict(feature); 198 L otherPrediction = otherPredictor.predict(feature); 199 return add.apply(thisPrediction, otherPrediction); 200 } 201 202 @Override 203 public Mono<L> predictAsync(F input) { 204 Mono<L> thisPrediction = thisPredictor.predictAsync(input); 205 Mono<L> otherPrediction = otherPredictor.predictAsync(input); 206 return Mono.zip(thisPrediction, otherPrediction).map(tuple -> add.apply(tuple.getT1(), tuple.getT2())); 207 } 208 209 @Override 210 public E getError() { 211 return errorComputer == null ? null : errorComputer.computeError(this, samples, featureMapper, labelMapper); 212 } 213 214 }; 215 } 216 217 @Override 218 public <S> Mono<FittedPredictor<F, L, E>> fitAsync( 219 Flux<S> samples, 220 Function<S, Mono<F>> featureMapper, 221 Function<S, Mono<L>> labelMapper) { 222 223 throw new UnsupportedOperationException("Implement me!"); 224 225// Collection<S> samplesSoFar = Collections.synchronizedCollection(new ArrayList<>()); 226// samples.subscribe(samplesSoFar::add); 227// 228// Mono<FittedPredictor<F,L,E>> thisPredictorMono = Fitter.this.fitAsync(samples, featureMapper, labelMapper); 229// Mono<FittedPredictor<F, L, E>> otherPredictorMono = thisPredictorMono.flatMap(thisPredictor -> { 230// otherPredictorMono = other.fitAsync( 231// samples, 232// featureMapper, 233// s -> { 234// Mono<L> label = labelMapper.apply(s); 235// Mono<L> prediction = featureMapper.apply(s).flatMap(f -> thisPredictor.predictAsync(f)); thisPredictor.predictAsync(); 236// return subtract.apply(label, prediction); 237// }); 238// 239// }); 240// 241// return Mono.zip(thisPredictorMono, otherPredictorMono, (thisPredictor, otherPredictor) -> { 242// 243// return new FittedPredictor<F,L,E>() { 244// 245// @Override 246// public L predict(F feature) { 247// L thisPrediction = thisPredictor.predict(feature); 248// L otherPrediction = otherPredictor.predict(feature); 249// return add.apply(thisPrediction, otherPrediction); 250// } 251// 252// @Override 253// public Mono<L> predictAsync(F input) { 254// Mono<L> thisPrediction = thisPredictor.predictAsync(input); 255// Mono<L> otherPrediction = otherPredictor.predictAsync(input); 256// return Mono.zip(thisPrediction, otherPrediction).map(tuple -> add.apply(tuple.getT1(), tuple.getT2())); 257// } 258// 259// @Override 260// public E getError() { 261// return errorComputer == null ? null : errorComputer.computeError(this, samplesSoFar, featureMapper, labelMapper); 262// } 263// 264// }; 265// 266// }); 267 } 268 269 }; 270 271 } 272 273// default Fitter<F,L,E> composeAsync( 274// Fitter<F,L,E> other, 275// BiFunction<L,L,Mono<L>> add, 276// BiFunction<L,L,Mono<L>> subtract, 277// ErrorComputer<F, L, E> errorComputer) { 278// 279// 280// } 281 282 } 283 284 /** 285 * @return Fitting residual error 286 */ 287 E getError(); 288 289 290 @Override 291 default <G> FittedPredictor<G,L,E> adaptFeature(Function<G,F> featureMapper) { 292 293 return new FittedPredictor<G,L,E>() { 294 295 @Override 296 public L predict(G feature) { 297 return FittedPredictor.this.predict(featureMapper.apply(feature)); 298 } 299 300 @Override 301 public Mono<L> predictAsync(G feature) { 302 return FittedPredictor.this.predictAsync(featureMapper.apply(feature)); 303 } 304 305 @Override 306 public E getError() { 307 return FittedPredictor.this.getError(); 308 } 309 310 }; 311 312 } 313 314 default <G> FittedPredictor<G,L,E> adaptFeatureAsync(Function<G,Mono<F>> featureMapper) { 315 316 return new FittedPredictor<G,L,E>() { 317 318 @Override 319 public Mono<L> predictAsync(G feature) { 320 return featureMapper.apply(feature).flatMap(FittedPredictor.this::predictAsync); 321 } 322 323 @Override 324 public E getError() { 325 return FittedPredictor.this.getError(); 326 } 327 328 }; 329 330 } 331 332 default <M> FittedPredictor<F,M,E> adaptLabel(Function<L,M> labelMapper) { 333 334 return new FittedPredictor<F,M,E>() { 335 336 @Override 337 public M predict(F feature) { 338 return labelMapper.apply(FittedPredictor.this.predict(feature)); 339 } 340 341 @Override 342 public Mono<M> predictAsync(F feature) { 343 return FittedPredictor.this.predictAsync(feature).map(labelMapper); 344 } 345 346 @Override 347 public E getError() { 348 return FittedPredictor.this.getError(); 349 } 350 351 }; 352 353 } 354 355 default <M> FittedPredictor<F,M,E> adaptLabelAsync(Function<L,Mono<M>> labelMapper) { 356 357 return new FittedPredictor<F,M,E>() { 358 359 @Override 360 public Mono<M> predictAsync(F feature) { 361 return FittedPredictor.this.predictAsync(feature).flatMap(labelMapper); 362 } 363 364 @Override 365 public E getError() { 366 return FittedPredictor.this.getError(); 367 } 368 369 }; 370 371 } 372 373}