package org.nasdanika.ai;

import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;

import reactor.core.publisher.Mono;

/**
 * Embeddings "business" interface focusing on ease of use and leaving
 * token usage reporting to implementations.
 */
public interface Embeddings extends Model {

	/**
	 * Embeddings requirement.
	 * String attributes match any value if null.
	 */
	record Requirement(
		String provider,
		String model,
		String version) {}

	/**
	 * Checks whether the input exceeds what this model accepts.
	 *
	 * @param input text to check
	 * @return true if the input is too long for a given model
	 */
	boolean isTooLong(String input);

	/**
	 * @return number of dimensions
	 */
	int getDimensions();

	/**
	 * Generates embeddings for a single string by blocking on
	 * {@link #generateAsync(String)}.
	 *
	 * @param input text to embed
	 * @return embedding vectors for the input, or {@code null} if the
	 *         asynchronous call completes without emitting a value
	 */
	default List<List<Float>> generate(String input) {
		return generateAsync(input).block();
	}

	/**
	 * Asynchronously generates embeddings for a single string.
	 *
	 * @param input text to embed
	 * @return a {@link Mono} providing embedding vectors for the input
	 */
	Mono<List<List<Float>>> generateAsync(String input);

	/**
	 * Batch generation by blocking on {@link #generateAsync(List)}.
	 *
	 * @param input a list of input strings
	 * @return a map from input string to its embedding vectors
	 */
	default Map<String, List<List<Float>>> generate(List<String> input) {
		return generateAsync(input).block();
	}

	/**
	 * Asynchronous batch generation. This default implementation calls
	 * {@link #generateAsync(String)} once per list element and combines the
	 * results into a map keyed by input string. If the same string occurs more
	 * than once it is requested for each occurrence and the map retains a
	 * single entry for it.
	 *
	 * @param input a list of input strings
	 * @return a {@link Mono} providing a map from input string to its embedding
	 *         vectors; an empty map for an empty input list
	 */
	default Mono<Map<String, List<List<Float>>>> generateAsync(List<String> input) {
		// Mono.zip over an empty collection completes without emitting,
		// which would surface as null from the blocking generate(List).
		if (input.isEmpty()) {
			Map<String, List<List<Float>>> empty = new LinkedHashMap<>();
			return Mono.just(empty);
		}

		List<Mono<Entry<String, List<List<Float>>>>> monos = input
			.stream()
			.map(ie -> {
				Mono<List<List<Float>>> emb = generateAsync(ie);
				return emb.map(vector -> Map.entry(ie, vector));
			})
			.toList();

		return Mono.zip(monos, this::combine);
	}

	/**
	 * Collects zipped results into a map, preserving the order in which keys
	 * are first inserted.
	 *
	 * @param elements entries produced by the zipped monos
	 * @return a map built from the entries
	 */
	private Map<String, List<List<Float>>> combine(Object[] elements) {
		Map<String, List<List<Float>>> ret = new LinkedHashMap<>();
		for (Object el : elements) {
			// Safe: every element comes from a Mono<Entry<String, List<List<Float>>>> built in generateAsync(List)
			@SuppressWarnings("unchecked")
			Entry<String, List<List<Float>>> e = (Entry<String, List<List<Float>>>) el;
			ret.put(e.getKey(), e.getValue());
		}
		return ret;
	}

}