001package org.nasdanika.ai;
002
003import java.util.Collection;
004import java.util.function.BinaryOperator;
005import java.util.function.Function;
006
007import reactor.core.publisher.Flux;
008import reactor.core.publisher.Mono;
009
010/**
011 * A predictor which is fitted (trained) 
012 * @param <F>
013 * @param <L>
014 * @param <E>
015 */
016public interface FittedPredictor<F,L,E> extends Predictor<F,L> {
017        
018        interface ErrorComputer<F,L,E> {
019                
020                <S> E computeError(Predictor<F,L> predictor, 
021                                Collection<S> samples, 
022                                Function<S, F> featureMapper,
023                                Function<S, L> labelMapper);
024                
025        }
026        
027        interface Fitter<F,L,E> {
028                
029                /**
030                 * Creates a predictor by fitting a collection of samples.
031                 * @param <S>
032                 * @param samples
033                 * @param featureMapper
034                 * @param labelMapper
035                 * @return
036                 */
037                default <S> FittedPredictor<F,L,E> fit(
038                                Collection<S> samples, 
039                                Function<S,F> featureMapper, 
040                                Function<S,L> labelMapper) {
041                        
042                        return fitAsync(
043                                Flux.fromIterable(samples),
044                                s -> Mono.fromSupplier(() -> featureMapper.apply(s)),
045                                s -> Mono.fromSupplier(() -> labelMapper.apply(s))).block();
046                }
047                
048                /**
049                 * Creates a predictor by fitting a flux of samples.
050                 * @param <S>
051                 * @param samples
052                 * @param featureMapper
053                 * @param labelMapper
054                 * @return A mono providing a predictor. 
055                 * The mono can publish a predictor before the flux is finished and update the predictor with new
056                 * samples from the flux.  
057                 */
058                <S> Mono<FittedPredictor<F,L,E>> fitAsync(
059                                Flux<S> samples, 
060                                Function<S,Mono<F>> featureMapper, 
061                                Function<S,Mono<L>> labelMapper);
062                                
063                default <G> Fitter<G,L,E> adaptFeature(Function<G,F> featureMapper) {
064                        
065                        return new Fitter<G,L,E>() {
066
067                                @Override
068                                public <S> Mono<FittedPredictor<G, L, E>> fitAsync(
069                                                Flux<S> samples, 
070                                                Function<S, Mono<G>> theFeatureMapper,
071                                                Function<S, Mono<L>> labelMapper) {
072                                        
073                                        Function<S, Mono<F>> featureMapperChain = theFeatureMapper.andThen(m -> m.map(featureMapper));
074                                        Mono<FittedPredictor<F, L, E>> fResult = Fitter.this.fitAsync(samples, featureMapperChain, labelMapper);
075                                        return fResult.map(fp -> fp.adaptFeature(featureMapper));
076                                }
077                                
078                                @Override
079                                public <S> FittedPredictor<G, L, E> fit(
080                                                Collection<S> samples, 
081                                                Function<S, G> theFeatureMapper,
082                                                Function<S, L> labelMapper) {
083                                        
084                                        FittedPredictor<F, L, E> predictor = Fitter.this.fit(samples, theFeatureMapper.andThen(featureMapper), labelMapper);
085                                        return predictor.adaptFeature(featureMapper);
086                                }
087                                
088                        };
089                                                
090                }       
091                        
092                default <G> Fitter<G,L,E> adaptFeatureAsync(Function<G,Mono<F>> featureMapper) {                        
093                        
094                        return new Fitter<G,L,E>() {
095
096                                @Override
097                                public <S> Mono<FittedPredictor<G, L, E>> fitAsync(
098                                                Flux<S> samples, 
099                                                Function<S, Mono<G>> theFeatureMapper,
100                                                Function<S, Mono<L>> labelMapper) {
101                                        
102                                        Function<S, Mono<F>> featureMapperChain = theFeatureMapper.andThen(m -> m.flatMap(featureMapper));
103                                        Mono<FittedPredictor<F, L, E>> fResult = Fitter.this.fitAsync(samples, featureMapperChain, labelMapper);
104                                        return fResult.map(fp -> fp.adaptFeatureAsync(featureMapper));
105                                }
106                                
107                        };
108                        
109                }       
110                
111                default <M> Fitter<F,M,E> adaptLabel(Function<M,L> fitMapper, Function<L,M> predictMapper) {
112                        
113                        return new Fitter<F,M,E>() {
114                                
115                                @Override
116                                public <S> FittedPredictor<F, M, E> fit(
117                                                Collection<S> samples, 
118                                                Function<S, F> featureMapper,
119                                                Function<S, M> labelMapper) {
120                                        
121                                        Function<S, L> fitMapperChain = labelMapper.andThen(fitMapper);
122                                        FittedPredictor<F, L, E> result = Fitter.this.fit(samples, featureMapper, fitMapperChain);
123                                        return result.adaptLabel(predictMapper);
124                                }
125
126                                @Override
127                                public <S> Mono<FittedPredictor<F, M, E>> fitAsync(
128                                                Flux<S> samples, 
129                                                Function<S, Mono<F>> featureMapper,
130                                                Function<S, Mono<M>> labelMapper) {
131                                        
132                                        Function<S, Mono<L>> fitMapperChain = labelMapper.andThen(m -> m.map(fitMapper));
133                                        Mono<FittedPredictor<F, L, E>> result = Fitter.this.fitAsync(samples, featureMapper, fitMapperChain);
134                                        return result.map(fp -> fp.adaptLabel(predictMapper));
135                                }
136                                
137                        };
138                                                
139                }       
140                        
141                default <M> Fitter<F,M,E> adaptLabelAsync(Function<M,Mono<L>> fitMapper, Function<L,Mono<M>> predictMapper) {
142                        
143                        return new Fitter<F,M,E>() {
144
145                                @Override
146                                public <S> Mono<FittedPredictor<F, M, E>> fitAsync(
147                                                Flux<S> samples, 
148                                                Function<S, Mono<F>> featureMapper,
149                                                Function<S, Mono<M>> labelMapper) {
150                                        
151                                        Function<S, Mono<L>> fitMapperChain = labelMapper.andThen(m -> m.flatMap(fitMapper));
152                                        Mono<FittedPredictor<F, L, E>> result = Fitter.this.fitAsync(samples, featureMapper, fitMapperChain);
153                                        return result.map(fp -> fp.adaptLabelAsync(predictMapper));
154                                }
155                                
156                        };
157                        
158                }       
159                
160                /**
161                 * Composes two predictors by fitting this one, then computing label difference between this predictor predictions and 
162                 * labels and fitting the next one with the difference. 
163                 * Prediction is done by computing prediction for this one, then the other and adding the two together.
164                 * @param other
165                 * @param add
166                 * @param subtract
167                 * @return
168                 */
169                default Fitter<F,L,E> compose(
170                                Fitter<F,L,E> other, 
171                                BinaryOperator<L> add, 
172                                BinaryOperator<L> subtract,
173                                ErrorComputer<F, L, E> errorComputer) {
174                        
175                        return new Fitter<F, L, E>() {
176                                
177                                @Override
178                                public <S> FittedPredictor<F, L, E> fit(
179                                                Collection<S> samples, 
180                                                Function<S, F> featureMapper,
181                                                Function<S, L> labelMapper) {
182                                        
183                                        FittedPredictor<F,L,E> thisPredictor = Fitter.this.fit(samples, featureMapper, labelMapper);                                    
184                                        FittedPredictor<F, L, E> otherPredictor = other.fit(
185                                                        samples, 
186                                                        featureMapper, 
187                                                        s -> {
188                                                                L label = labelMapper.apply(s);
189                                                                L prediction = thisPredictor.predict(featureMapper.apply(s));
190                                                                return subtract.apply(label, prediction);
191                                                        });
192                                                                                
193                                        return new FittedPredictor<F,L,E>() {
194                                                
195                                                @Override
196                                                public L predict(F feature) {
197                                                        L thisPrediction = thisPredictor.predict(feature);
198                                                        L otherPrediction = otherPredictor.predict(feature);
199                                                        return add.apply(thisPrediction, otherPrediction);
200                                                }
201
202                                                @Override
203                                                public Mono<L> predictAsync(F input) {
204                                                        Mono<L> thisPrediction = thisPredictor.predictAsync(input);
205                                                        Mono<L> otherPrediction = otherPredictor.predictAsync(input);                                                   
206                                                        return Mono.zip(thisPrediction, otherPrediction).map(tuple -> add.apply(tuple.getT1(), tuple.getT2()));
207                                                }
208
209                                                @Override
210                                                public E getError() {
211                                                        return errorComputer == null ? null : errorComputer.computeError(this, samples, featureMapper, labelMapper);
212                                                }
213                                                
214                                        };
215                                }
216                                
217                                @Override
218                                public <S> Mono<FittedPredictor<F, L, E>> fitAsync(
219                                                Flux<S> samples, 
220                                                Function<S, Mono<F>> featureMapper,
221                                                Function<S, Mono<L>> labelMapper) {
222                                        
223                                        throw new UnsupportedOperationException("Implement me!");
224                                        
225//                                      Collection<S> samplesSoFar = Collections.synchronizedCollection(new ArrayList<>());
226//                                      samples.subscribe(samplesSoFar::add);
227//                                      
228//                                      Mono<FittedPredictor<F,L,E>> thisPredictorMono = Fitter.this.fitAsync(samples, featureMapper, labelMapper);             
229//                                      Mono<FittedPredictor<F, L, E>> otherPredictorMono = thisPredictorMono.flatMap(thisPredictor -> {
230//                                              otherPredictorMono = other.fitAsync(
231//                                                              samples, 
232//                                                              featureMapper, 
233//                                                              s -> {
234//                                                                      Mono<L> label = labelMapper.apply(s);
235//                                                                      Mono<L> prediction = featureMapper.apply(s).flatMap(f -> thisPredictor.predictAsync(f));   thisPredictor.predictAsync();
236//                                                                      return subtract.apply(label, prediction);
237//                                                              });
238//                                              
239//                                      });
240//                                      
241//                                      return Mono.zip(thisPredictorMono, otherPredictorMono, (thisPredictor, otherPredictor) -> {
242//                                              
243//                                              return new FittedPredictor<F,L,E>() {
244//                                                      
245//                                                      @Override
246//                                                      public L predict(F feature) {
247//                                                              L thisPrediction = thisPredictor.predict(feature);
248//                                                              L otherPrediction = otherPredictor.predict(feature);
249//                                                              return add.apply(thisPrediction, otherPrediction);
250//                                                      }
251//
252//                                                      @Override
253//                                                      public Mono<L> predictAsync(F input) {
254//                                                              Mono<L> thisPrediction = thisPredictor.predictAsync(input);
255//                                                              Mono<L> otherPrediction = otherPredictor.predictAsync(input);                                                   
256//                                                              return Mono.zip(thisPrediction, otherPrediction).map(tuple -> add.apply(tuple.getT1(), tuple.getT2()));
257//                                                      }
258//
259//                                                      @Override
260//                                                      public E getError() {
261//                                                              return errorComputer == null ? null : errorComputer.computeError(this, samplesSoFar, featureMapper, labelMapper);
262//                                                      }
263//                                                      
264//                                              };                                              
265//                                              
266//                                      });                                     
267                                }                               
268                                
269                        };
270                        
271                }
272                        
273//              default Fitter<F,L,E> composeAsync(
274//                              Fitter<F,L,E> other, 
275//                              BiFunction<L,L,Mono<L>> add, 
276//                              BiFunction<L,L,Mono<L>> subtract,
277//                              ErrorComputer<F, L, E> errorComputer) {
278//
279//                      
280//              }
281                                
282        }       
283        
284        /**
285         * @return Fitting residual error
286         */
287        E getError();
288        
289        
290        @Override
291        default <G> FittedPredictor<G,L,E> adaptFeature(Function<G,F> featureMapper) {
292                
293                return new FittedPredictor<G,L,E>() {
294                        
295                        @Override
296                        public L predict(G feature) {
297                                return FittedPredictor.this.predict(featureMapper.apply(feature));
298                        }
299
300                        @Override
301                        public Mono<L> predictAsync(G feature) {
302                                return FittedPredictor.this.predictAsync(featureMapper.apply(feature));
303                        }
304                        
305                        @Override
306                        public E getError() {
307                                return FittedPredictor.this.getError();
308                        }
309                        
310                };
311                
312        }       
313                
314        default <G> FittedPredictor<G,L,E> adaptFeatureAsync(Function<G,Mono<F>> featureMapper) {
315                
316                return new FittedPredictor<G,L,E>() {
317
318                        @Override
319                        public Mono<L> predictAsync(G feature) {
320                                return featureMapper.apply(feature).flatMap(FittedPredictor.this::predictAsync);
321                        }
322                        
323                        @Override
324                        public E getError() {
325                                return FittedPredictor.this.getError();
326                        }
327                        
328                };
329                
330        }       
331        
332        default <M> FittedPredictor<F,M,E> adaptLabel(Function<L,M> labelMapper) {
333                
334                return new FittedPredictor<F,M,E>() {
335                        
336                        @Override
337                        public M predict(F feature) {
338                                return labelMapper.apply(FittedPredictor.this.predict(feature));
339                        }
340
341                        @Override
342                        public Mono<M> predictAsync(F feature) {
343                                return FittedPredictor.this.predictAsync(feature).map(labelMapper);
344                        }
345                        
346                        @Override
347                        public E getError() {
348                                return FittedPredictor.this.getError();
349                        }
350                        
351                };
352                
353        }       
354                
355        default <M> FittedPredictor<F,M,E> adaptLabelAsync(Function<L,Mono<M>> labelMapper) {
356                
357                return new FittedPredictor<F,M,E>() {
358
359                        @Override
360                        public Mono<M> predictAsync(F feature) {                                
361                                return FittedPredictor.this.predictAsync(feature).flatMap(labelMapper);
362                        }
363                        
364                        @Override
365                        public E getError() {
366                                return FittedPredictor.this.getError();
367                        }
368                        
369                };
370                
371        }               
372
373}