001package org.nasdanika.ai;
002
003import java.util.LinkedHashMap;
004import java.util.List;
005import java.util.Map;
006import java.util.Map.Entry;
007
008import reactor.core.publisher.Mono;
009
010/**
011 * Embeddings "business" interface focusing on ease of use and leaving
012 * token usage reporting to implementations.
013 */
014public interface Embeddings extends Model {
015        
016        /**
017         * Embeddings requirement.
018         * String attributes match any value if null.
019         * Chunk size is the max input size if more than the max input size or non-positive.
020         */
021        record Requirement(
022                String provider,
023                String model,
024                String version,
025                int chunkSize,
026                int overlap) {}
027        
028        /**
029         * 
030         * @param input
031         * @return true if the input is too long for a given model
032         */
033        boolean isTooLong(String input);
034        
035        /**
036         * @return number of dimentions
037         */
038        int getDimensions();
039        
040        /**
041         * Generates embeddings for a single string
042         * @param model
043         * @param input
044         * @return
045         */
046        default List<List<Float>> generate(String input) {
047                return generateAsync(input).block();
048        }
049        
050        /**
051         * Asynchronously generates embeddings for a single string
052         * @param model
053         * @param input
054         * @return
055         */
056        Mono<List<List<Float>>> generateAsync(String input);
057        
058        /**
059         * Batch generation
060         * @param input a list of input strings
061         * @return 
062         */
063        default Map<String, List<List<Float>>> generate(List<String> input) {
064                return generateAsync(input).block();
065        }
066        
067        /**
068         * Asynchronous batch generation
069         * @param input a list of input strings
070         * @return 
071         */
072        default Mono<Map<String, List<List<Float>>>> generateAsync(List<String> input) {
073                List<Mono<Entry<String,List<List<Float>>>>> monos = input
074                        .stream()
075                        .map(ie -> {
076                                Mono<List<List<Float>>> emb = generateAsync(ie);
077                                return emb.map(vector -> Map.entry(ie, vector));
078                        })
079                        .toList();
080                
081                return Mono.zip(monos, this::combine);
082        }
083        
084        private Map<String, List<List<Float>>> combine(Object[] elements) {
085                Map<String, List<List<Float>>> ret = new LinkedHashMap<>();
086                for (Object el: elements) {
087                        @SuppressWarnings("unchecked")
088                        Entry<String,List<List<Float>>> e = (Entry<String,List<List<Float>>>) el;
089                        ret.put(e.getKey(), e.getValue());
090                }               
091                return ret;
092        }
093
094}