001package org.nasdanika.ai;
002
003import java.util.Collection;
004import java.util.LinkedHashMap;
005import java.util.List;
006import java.util.Map;
007import java.util.Map.Entry;
008import java.util.Optional;
009import java.util.function.BiFunction;
010import java.util.function.BinaryOperator;
011import java.util.function.Function;
012import java.util.function.Predicate;
013import java.util.stream.Stream;
014
015import org.nasdanika.common.Composable;
016
017import reactor.core.publisher.Mono;
018
019/**
020 * Generates an embedding from source. 
021 * @param <S>
022 * @param <E>
023 */
024public interface EmbeddingGenerator<S,E> extends Composable<EmbeddingGenerator<S,E>> {
025        
026        /**
027         * {@link EmbeddingGenerator} requirement.
028         * Predicates can be null.
029         */
030        record Requirement(
031                Class<?> sourceType,
032                Class<?> embeddingType,
033                Predicate<Class<? extends EmbeddingGenerator<?,?>>> typePredicate,              
034                Predicate<EmbeddingGenerator<?,?>> predicate) {}        
035        
036        default E generate(S input) {
037                return generateAsync(input).block();
038        }
039        
040        Mono<E> generateAsync(S input);
041        
042        /**
043         * Batch generation
044         */
045        default Map<S, E> generate(Collection<S> input) {
046                return generateAsync(input).block();
047        }
048        
049        /**
050         * Asynchronous batch generation
051         */
052        default Mono<Map<S, E>> generateAsync(Collection<S> input) {
053                List<Mono<Entry<S,E>>> monos = input
054                        .stream()
055                        .map(ie -> {
056                                Mono<E> embMono = generateAsync(ie);
057                                return embMono.map(emb -> Map.entry(ie, emb));
058                        })
059                        .toList();
060                
061                return Mono.zip(monos, this::combine);
062        }
063                
064        private Map<S, E> combine(Object[] elements) {
065                Map<S, E> ret = new LinkedHashMap<>();
066                for (Object el: elements) {
067                        @SuppressWarnings("unchecked")
068                        Entry<S,E> e = (Entry<S,E>) el;
069                        ret.put(e.getKey(), e.getValue());
070                }               
071                return ret;
072        }       
073        
074        default <F> EmbeddingGenerator<S,F> then(EmbeddingGenerator<E,F> next) {
075                return new EmbeddingGenerator<S, F>() {
076
077                        @Override
078                        public Mono<F> generateAsync(S source) {
079                                return EmbeddingGenerator.this.generateAsync(source).flatMap(next::generateAsync);
080                        }
081                        
082                        @Override
083                        public F generate(S input) {
084                                return next.generate(EmbeddingGenerator.this.generate(input));
085                        }
086                        
087                };
088                
089        }
090        
091        default <V> EmbeddingGenerator<V,E> adapt(Function<V,Mono<S>> mapper) {
092                
093                return new EmbeddingGenerator<V, E>() {
094
095                        @Override
096                        public Mono<E> generateAsync(V source) {
097                                return mapper.apply(source).flatMap(EmbeddingGenerator.this::generateAsync);
098                        }
099                        
100                };
101                
102        }       
103        
104        /**
105         * Calls this embedding generator and returns its return value if it is not null and composer is null.
106         * Otherwise calls the other embedding generator and then composes results by calling the composer argument. 
107         * @param other
108         * @param
109         * @return
110         */
111        default EmbeddingGenerator<S,E> compose(EmbeddingGenerator<? super S,? extends E> other, BinaryOperator<E> composer) {
112                if (other == null) {
113                        return this;
114                }
115                
116                return new EmbeddingGenerator<S, E>() {
117
118                        @Override
119                        public Mono<E> generateAsync(S input) {
120                                Mono<E> thisResult = EmbeddingGenerator.this.generateAsync(input);
121                                if (composer == null) {                                 
122                                        return thisResult.switchIfEmpty(other.generateAsync(input));
123                                }
124                                
125                                Mono<? extends E> otherResult = other.generateAsync(input);                             
126                        Function<E, Mono<E>> transformer = a -> otherResult.map(b -> composer.apply(a, b)).defaultIfEmpty(a);
127                        
128                                return thisResult
129                                        .flatMap(transformer)
130                                        .switchIfEmpty(otherResult);
131                        }
132                        
133                        @Override
134                        public E generate(S input) {
135                                E thisResult = EmbeddingGenerator.this.generate(input);
136                                if (composer == null) {
137                                        return thisResult == null ? other.generate(input) : thisResult;
138                                }
139                                
140                                return composer.apply(thisResult, other.generate(input));
141                        }
142                        
143                };
144        }
145        
146        /**
147         * Calls this embedding generator and returns its return value if it is not null and composer is null.
148         * Otherwise calls the other embedding generator and then composes results by calling the composer argument. 
149         * @param other
150         * @param
151         * @return
152         */
153        default EmbeddingGenerator<S,E> composeAsync(EmbeddingGenerator<? super S,? extends E> other, BiFunction<? super E, ? super E, Mono<E>> composer) {
154                if (other == null) {
155                        return this;
156                }
157                
158                return new EmbeddingGenerator<S, E>() {
159
160                        @Override
161                        public Mono<E> generateAsync(S input) {
162                                Mono<E> thisResult = EmbeddingGenerator.this.generateAsync(input);
163                                if (composer == null) {                                 
164                                        return thisResult.switchIfEmpty(other.generateAsync(input));
165                                }
166                                
167                                Mono<? extends E> otherResult = other.generateAsync(input);                             
168                        Function<E, Mono<E>> transformer = a -> otherResult.flatMap(b -> composer.apply(a, b)).defaultIfEmpty(a);
169                        
170                                return thisResult
171                                        .flatMap(transformer)
172                                        .switchIfEmpty(otherResult);
173                        }
174                        
175                };
176        }
177        
178        @Override
179        default EmbeddingGenerator<S, E> compose(EmbeddingGenerator<S, E> other) {
180                return compose(other, null);
181        }
182        
183        /**
184         * @param <T> Instances of T shall implement {@link Composable}.
185         * @return Composing operator which can be use in reducing streams of {@link Composable}s to a single composeable.
186         */
187        static <S,E> EmbeddingGenerator<S,E> compose(EmbeddingGenerator<S,E> a, EmbeddingGenerator<S,E> b, BinaryOperator<E> composer) {        
188                if (a == null) {
189                        return b;
190                }
191                if (b == null) {
192                        return a;
193                }
194                return a.compose(b, composer);
195        }       
196        
197        static <S,E> Optional<EmbeddingGenerator<S, E>> reduce(Stream<EmbeddingGenerator<S, E>> stream, BinaryOperator<E> composer) {
198                return stream.reduce((a, b) -> compose(a, b, composer));
199        }
200
201}