001package org.nasdanika.ai;
002
003import java.util.Collection;
004import java.util.List;
005import java.util.function.Function;
006
007import reactor.core.publisher.Flux;
008import reactor.core.publisher.Mono;
009
010/**
011 * Predicts output from input 
012 * @param <F> feature(s) - predictor input
013 * @param <L> label - predictor output
014 */
015public interface Predictor<F,L> {
016        
017        
018        default L predict(F feature) {
019                return predictAsync(feature).block();
020        }
021        
022        Mono<L> predictAsync(F input);
023        
024        record Sample<F,L>(F feature, L label) {}
025        
026        /**
027         * Batch prediction
028         */
029        default List<Sample<F, L>> predict(Collection<F> input) {
030                return predictAsync(Flux.fromIterable(input)).block();
031        }
032        
033        /**
034         * Asynchronous batch generation
035         */
036        default Mono<List<Sample<F, L>>> predictAsync(Flux<F> input) {
037                return input
038                        .flatMap(ie -> {
039                                Mono<L> embMono = predictAsync(ie);
040                                return embMono.map(emb -> new Sample<>(ie, emb));
041                        })
042                        .collectList();
043        }
044        
045        default <G> Predictor<G,L> adaptFeature(Function<G,F> featureMapper) {
046                
047                return new Predictor<G,L>() {
048                        
049                        @Override
050                        public L predict(G feature) {
051                                return Predictor.this.predict(featureMapper.apply(feature));
052                        }
053
054                        @Override
055                        public Mono<L> predictAsync(G feature) {
056                                return Predictor.this.predictAsync(featureMapper.apply(feature));
057                        }
058                        
059                };
060                
061        }       
062                
063        default <G> Predictor<G,L> adaptFeatureAsync(Function<G,Mono<F>> featureMapper) {
064                
065                return new Predictor<G,L>() {
066
067                        @Override
068                        public Mono<L> predictAsync(G feature) {
069                                return featureMapper.apply(feature).flatMap(Predictor.this::predictAsync);
070                        }
071                        
072                };
073                
074        }       
075        
076        default <M> Predictor<F,M> adaptLabel(Function<L,M> labelMapper) {
077                
078                return new Predictor<F,M>() {
079                        
080                        @Override
081                        public M predict(F feature) {
082                                return labelMapper.apply(Predictor.this.predict(feature));
083                        }
084
085                        @Override
086                        public Mono<M> predictAsync(F feature) {
087                                return Predictor.this.predictAsync(feature).map(labelMapper);
088                        }
089                        
090                };
091                
092        }       
093                
094        default <M> Predictor<F,M> adaptLabelAsync(Function<L,Mono<M>> labelMapper) {
095                
096                return new Predictor<F,M>() {
097
098                        @Override
099                        public Mono<M> predictAsync(F feature) {                                
100                                return Predictor.this.predictAsync(feature).flatMap(labelMapper);
101                        }
102                        
103                };
104                
105        }       
106
107}