001package org.nasdanika.ai;
002
003import java.util.LinkedHashMap;
004import java.util.List;
005import java.util.Map;
006import java.util.Map.Entry;
007import java.util.function.Function;
008import java.util.function.Predicate;
009
010import reactor.core.publisher.Mono;
011
012/**
013 * Generates an embedding from source. 
014 * @param <S>
015 * @param <E>
016 */
017public interface EmbeddingGenerator<S,E> {
018        
019        /**
020         * {@link EmbeddingGenerator} requirement.
021         * Predicates can be null.
022         */
023        record Requirement(
024                Class<?> sourceType,
025                Class<?> embeddingType,
026                Predicate<Class<? extends EmbeddingGenerator<?,?>>> typePredicate,              
027                Predicate<EmbeddingGenerator<?,?>> predicate) {}        
028        
029        default E generate(S input) {
030                return generateAsync(input).block();
031        }
032        
033        Mono<E> generateAsync(S input);
034        
035        /**
036         * Batch generation
037         */
038        default Map<S, E> generate(List<S> input) {
039                return generateAsync(input).block();
040        }
041        
042        /**
043         * Asynchronous batch generation
044         */
045        default Mono<Map<S, E>> generateAsync(List<S> input) {
046                List<Mono<Entry<S,E>>> monos = input
047                        .stream()
048                        .map(ie -> {
049                                Mono<E> embMono = generateAsync(ie);
050                                return embMono.map(emb -> Map.entry(ie, emb));
051                        })
052                        .toList();
053                
054                return Mono.zip(monos, this::combine);
055        }
056                
057        private Map<S, E> combine(Object[] elements) {
058                Map<S, E> ret = new LinkedHashMap<>();
059                for (Object el: elements) {
060                        @SuppressWarnings("unchecked")
061                        Entry<S,E> e = (Entry<S,E>) el;
062                        ret.put(e.getKey(), e.getValue());
063                }               
064                return ret;
065        }       
066        
067        default <F> EmbeddingGenerator<S,F> then(EmbeddingGenerator<E,F> next) {
068                return new EmbeddingGenerator<S, F>() {
069
070                        @Override
071                        public Mono<F> generateAsync(S source) {
072                                return EmbeddingGenerator.this.generateAsync(source).flatMap(next::generateAsync);
073                        }
074                        
075                };
076                
077        }
078        
079        default <V> EmbeddingGenerator<V,E> adapt(Function<V,Mono<S>> mapper) {
080                
081                return new EmbeddingGenerator<V, E>() {
082
083                        @Override
084                        public Mono<E> generateAsync(V source) {
085                                return mapper.apply(source).flatMap(EmbeddingGenerator.this::generateAsync);
086                        }
087                        
088                };
089                
090        }       
091
092}