001package org.nasdanika.ai; 002 003import java.util.Collection; 004import java.util.LinkedHashMap; 005import java.util.List; 006import java.util.Map; 007import java.util.Map.Entry; 008import java.util.Optional; 009import java.util.function.BiFunction; 010import java.util.function.BinaryOperator; 011import java.util.function.Function; 012import java.util.function.Predicate; 013import java.util.stream.Stream; 014 015import org.nasdanika.common.Composable; 016 017import reactor.core.publisher.Mono; 018 019/** 020 * Generates an embedding from source. 021 * @param <S> 022 * @param <E> 023 */ 024public interface EmbeddingGenerator<S,E> extends Composable<EmbeddingGenerator<S,E>> { 025 026 /** 027 * {@link EmbeddingGenerator} requirement. 028 * Predicates can be null. 029 */ 030 record Requirement( 031 Class<?> sourceType, 032 Class<?> embeddingType, 033 Predicate<Class<? extends EmbeddingGenerator<?,?>>> typePredicate, 034 Predicate<EmbeddingGenerator<?,?>> predicate) {} 035 036 default E generate(S input) { 037 return generateAsync(input).block(); 038 } 039 040 Mono<E> generateAsync(S input); 041 042 /** 043 * Batch generation 044 */ 045 default Map<S, E> generate(Collection<S> input) { 046 return generateAsync(input).block(); 047 } 048 049 /** 050 * Asynchronous batch generation 051 */ 052 default Mono<Map<S, E>> generateAsync(Collection<S> input) { 053 List<Mono<Entry<S,E>>> monos = input 054 .stream() 055 .map(ie -> { 056 Mono<E> embMono = generateAsync(ie); 057 return embMono.map(emb -> Map.entry(ie, emb)); 058 }) 059 .toList(); 060 061 return Mono.zip(monos, this::combine); 062 } 063 064 private Map<S, E> combine(Object[] elements) { 065 Map<S, E> ret = new LinkedHashMap<>(); 066 for (Object el: elements) { 067 @SuppressWarnings("unchecked") 068 Entry<S,E> e = (Entry<S,E>) el; 069 ret.put(e.getKey(), e.getValue()); 070 } 071 return ret; 072 } 073 074 default <F> EmbeddingGenerator<S,F> then(EmbeddingGenerator<E,F> next) { 075 return new EmbeddingGenerator<S, F>() { 076 077 @Override 078 public Mono<F> generateAsync(S source) { 079 return EmbeddingGenerator.this.generateAsync(source).flatMap(next::generateAsync); 080 } 081 082 @Override 083 public F generate(S input) { 084 return next.generate(EmbeddingGenerator.this.generate(input)); 085 } 086 087 }; 088 089 } 090 091 default <V> EmbeddingGenerator<V,E> adapt(Function<V,Mono<S>> mapper) { 092 093 return new EmbeddingGenerator<V, E>() { 094 095 @Override 096 public Mono<E> generateAsync(V source) { 097 return mapper.apply(source).flatMap(EmbeddingGenerator.this::generateAsync); 098 } 099 100 }; 101 102 } 103 104 /** 105 * Calls this embedding generator and returns its return value if it is not null and composer is null. 106 * Otherwise calls the other embedding generator and then composes results by calling the composer argument. 107 * @param other 108 * @param 109 * @return 110 */ 111 default EmbeddingGenerator<S,E> compose(EmbeddingGenerator<? super S,? extends E> other, BinaryOperator<E> composer) { 112 if (other == null) { 113 return this; 114 } 115 116 return new EmbeddingGenerator<S, E>() { 117 118 @Override 119 public Mono<E> generateAsync(S input) { 120 Mono<E> thisResult = EmbeddingGenerator.this.generateAsync(input); 121 if (composer == null) { 122 return thisResult.switchIfEmpty(other.generateAsync(input)); 123 } 124 125 Mono<? extends E> otherResult = other.generateAsync(input); 126 Function<E, Mono<E>> transformer = a -> otherResult.map(b -> composer.apply(a, b)).defaultIfEmpty(a); 127 128 return thisResult 129 .flatMap(transformer) 130 .switchIfEmpty(otherResult); 131 } 132 133 @Override 134 public E generate(S input) { 135 E thisResult = EmbeddingGenerator.this.generate(input); 136 if (composer == null) { 137 return thisResult == null ? other.generate(input) : thisResult; 138 } 139 140 return composer.apply(thisResult, other.generate(input)); 141 } 142 143 }; 144 } 145 146 /** 147 * Calls this embedding generator and returns its return value if it is not null and composer is null. 148 * Otherwise calls the other embedding generator and then composes results by calling the composer argument. 149 * @param other 150 * @param 151 * @return 152 */ 153 default EmbeddingGenerator<S,E> composeAsync(EmbeddingGenerator<? super S,? extends E> other, BiFunction<? super E, ? super E, Mono<E>> composer) { 154 if (other == null) { 155 return this; 156 } 157 158 return new EmbeddingGenerator<S, E>() { 159 160 @Override 161 public Mono<E> generateAsync(S input) { 162 Mono<E> thisResult = EmbeddingGenerator.this.generateAsync(input); 163 if (composer == null) { 164 return thisResult.switchIfEmpty(other.generateAsync(input)); 165 } 166 167 Mono<? extends E> otherResult = other.generateAsync(input); 168 Function<E, Mono<E>> transformer = a -> otherResult.flatMap(b -> composer.apply(a, b)).defaultIfEmpty(a); 169 170 return thisResult 171 .flatMap(transformer) 172 .switchIfEmpty(otherResult); 173 } 174 175 }; 176 } 177 178 @Override 179 default EmbeddingGenerator<S, E> compose(EmbeddingGenerator<S, E> other) { 180 return compose(other, null); 181 } 182 183 /** 184 * @param <T> Instances of T shall implement {@link Composable}. 185 * @return Composing operator which can be use in reducing streams of {@link Composable}s to a single composeable. 186 */ 187 static <S,E> EmbeddingGenerator<S,E> compose(EmbeddingGenerator<S,E> a, EmbeddingGenerator<S,E> b, BinaryOperator<E> composer) { 188 if (a == null) { 189 return b; 190 } 191 if (b == null) { 192 return a; 193 } 194 return a.compose(b, composer); 195 } 196 197 static <S,E> Optional<EmbeddingGenerator<S, E>> reduce(Stream<EmbeddingGenerator<S, E>> stream, BinaryOperator<E> composer) { 198 return stream.reduce((a, b) -> compose(a, b, composer)); 199 } 200 201}