001package org.nasdanika.ai; 002 003import java.util.Collection; 004import java.util.List; 005import java.util.function.Function; 006 007import reactor.core.publisher.Flux; 008import reactor.core.publisher.Mono; 009 010/** 011 * Predicts output from input 012 * @param <F> feature(s) - predictor input 013 * @param <L> label - predictor output 014 */ 015public interface Predictor<F,L> { 016 017 018 default L predict(F feature) { 019 return predictAsync(feature).block(); 020 } 021 022 Mono<L> predictAsync(F input); 023 024 record Sample<F,L>(F feature, L label) {} 025 026 /** 027 * Batch prediction 028 */ 029 default List<Sample<F, L>> predict(Collection<F> input) { 030 return predictAsync(Flux.fromIterable(input)).block(); 031 } 032 033 /** 034 * Asynchronous batch generation 035 */ 036 default Mono<List<Sample<F, L>>> predictAsync(Flux<F> input) { 037 return input 038 .flatMap(ie -> { 039 Mono<L> embMono = predictAsync(ie); 040 return embMono.map(emb -> new Sample<>(ie, emb)); 041 }) 042 .collectList(); 043 } 044 045 default <G> Predictor<G,L> adaptFeature(Function<G,F> featureMapper) { 046 047 return new Predictor<G,L>() { 048 049 @Override 050 public L predict(G feature) { 051 return Predictor.this.predict(featureMapper.apply(feature)); 052 } 053 054 @Override 055 public Mono<L> predictAsync(G feature) { 056 return Predictor.this.predictAsync(featureMapper.apply(feature)); 057 } 058 059 }; 060 061 } 062 063 default <G> Predictor<G,L> adaptFeatureAsync(Function<G,Mono<F>> featureMapper) { 064 065 return new Predictor<G,L>() { 066 067 @Override 068 public Mono<L> predictAsync(G feature) { 069 return featureMapper.apply(feature).flatMap(Predictor.this::predictAsync); 070 } 071 072 }; 073 074 } 075 076 default <M> Predictor<F,M> adaptLabel(Function<L,M> labelMapper) { 077 078 return new Predictor<F,M>() { 079 080 @Override 081 public M predict(F feature) { 082 return labelMapper.apply(Predictor.this.predict(feature)); 083 } 084 085 @Override 086 public Mono<M> predictAsync(F feature) { 087 return Predictor.this.predictAsync(feature).map(labelMapper); 088 } 089 090 }; 091 092 } 093 094 default <M> Predictor<F,M> adaptLabelAsync(Function<L,Mono<M>> labelMapper) { 095 096 return new Predictor<F,M>() { 097 098 @Override 099 public Mono<M> predictAsync(F feature) { 100 return Predictor.this.predictAsync(feature).flatMap(labelMapper); 101 } 102 103 }; 104 105 } 106 107}