001package org.nasdanika.ai; 002 003import java.util.LinkedHashMap; 004import java.util.List; 005import java.util.Map; 006import java.util.Map.Entry; 007import java.util.function.Function; 008import java.util.function.Predicate; 009 010import reactor.core.publisher.Mono; 011 012/** 013 * Generates an embedding from source. 014 * @param <S> 015 * @param <E> 016 */ 017public interface EmbeddingGenerator<S,E> { 018 019 /** 020 * {@link EmbeddingGenerator} requirement. 021 * Predicates can be null. 022 */ 023 record Requirement( 024 Class<?> sourceType, 025 Class<?> embeddingType, 026 Predicate<Class<? extends EmbeddingGenerator<?,?>>> typePredicate, 027 Predicate<EmbeddingGenerator<?,?>> predicate) {} 028 029 default E generate(S input) { 030 return generateAsync(input).block(); 031 } 032 033 Mono<E> generateAsync(S input); 034 035 /** 036 * Batch generation 037 */ 038 default Map<S, E> generate(List<S> input) { 039 return generateAsync(input).block(); 040 } 041 042 /** 043 * Asynchronous batch generation 044 */ 045 default Mono<Map<S, E>> generateAsync(List<S> input) { 046 List<Mono<Entry<S,E>>> monos = input 047 .stream() 048 .map(ie -> { 049 Mono<E> embMono = generateAsync(ie); 050 return embMono.map(emb -> Map.entry(ie, emb)); 051 }) 052 .toList(); 053 054 return Mono.zip(monos, this::combine); 055 } 056 057 private Map<S, E> combine(Object[] elements) { 058 Map<S, E> ret = new LinkedHashMap<>(); 059 for (Object el: elements) { 060 @SuppressWarnings("unchecked") 061 Entry<S,E> e = (Entry<S,E>) el; 062 ret.put(e.getKey(), e.getValue()); 063 } 064 return ret; 065 } 066 067 default <F> EmbeddingGenerator<S,F> then(EmbeddingGenerator<E,F> next) { 068 return new EmbeddingGenerator<S, F>() { 069 070 @Override 071 public Mono<F> generateAsync(S source) { 072 return EmbeddingGenerator.this.generateAsync(source).flatMap(next::generateAsync); 073 } 074 075 }; 076 077 } 078 079 default <V> EmbeddingGenerator<V,E> adapt(Function<V,Mono<S>> mapper) { 080 081 return new EmbeddingGenerator<V, E>() { 082 083 @Override 084 public Mono<E> generateAsync(V source) { 085 return mapper.apply(source).flatMap(EmbeddingGenerator.this::generateAsync); 086 } 087 088 }; 089 090 } 091 092}