001package org.nasdanika.ai;
002
003import java.util.LinkedHashMap;
004import java.util.List;
005import java.util.Map;
006import java.util.Map.Entry;
007
008import reactor.core.publisher.Mono;
009
010/**
011 * Embeddings "business" interface focusing on ease of use and leaving
012 * token usage reporting to implementations.
013 */
014public interface Embeddings extends Model {
015        
016        /**
017         * Embeddings requirement.
018         * String attributes match any value if null.
019         */
020        record Requirement(
021                String provider,
022                String model,
023                String version) {}
024        
025        /**
026         * 
027         * @param input
028         * @return true if the input is too long for a given model
029         */
030        boolean isTooLong(String input);
031        
032        /**
033         * @return number of dimentions
034         */
035        int getDimensions();
036        
037        /**
038         * Generates embeddings for a single string
039         * @param model
040         * @param input
041         * @return
042         */
043        default List<List<Float>> generate(String input) {
044                return generateAsync(input).block();
045        }
046        
047        /**
048         * Asynchronously generates embeddings for a single string
049         * @param model
050         * @param input
051         * @return
052         */
053        Mono<List<List<Float>>> generateAsync(String input);
054        
055        /**
056         * Batch generation
057         * @param input a list of input strings
058         * @return 
059         */
060        default Map<String, List<List<Float>>> generate(List<String> input) {
061                return generateAsync(input).block();
062        }
063        
064        /**
065         * Asynchronous batch generation
066         * @param input a list of input strings
067         * @return 
068         */
069        default Mono<Map<String, List<List<Float>>>> generateAsync(List<String> input) {
070                List<Mono<Entry<String,List<List<Float>>>>> monos = input
071                        .stream()
072                        .map(ie -> {
073                                Mono<List<List<Float>>> emb = generateAsync(ie);
074                                return emb.map(vector -> Map.entry(ie, vector));
075                        })
076                        .toList();
077                
078                return Mono.zip(monos, this::combine);
079        }
080        
081        private Map<String, List<List<Float>>> combine(Object[] elements) {
082                Map<String, List<List<Float>>> ret = new LinkedHashMap<>();
083                for (Object el: elements) {
084                        @SuppressWarnings("unchecked")
085                        Entry<String,List<List<Float>>> e = (Entry<String,List<List<Float>>>) el;
086                        ret.put(e.getKey(), e.getValue());
087                }               
088                return ret;
089        }
090
091}