001package org.nasdanika.ai; 002 003import java.util.LinkedHashMap; 004import java.util.List; 005import java.util.Map; 006import java.util.Map.Entry; 007 008import reactor.core.publisher.Mono; 009 010/** 011 * Embeddings "business" interface focusing on ease of use and leaving 012 * token usage reporting to implementations. 013 */ 014public interface Embeddings extends Model { 015 016 /** 017 * Embeddings requirement. 018 * String attributes match any value if null. 019 * Chunk size is the max input size if more than the max input size or non-positive. 020 */ 021 record Requirement( 022 String provider, 023 String model, 024 String version, 025 int chunkSize, 026 int overlap) {} 027 028 /** 029 * 030 * @param input 031 * @return true if the input is too long for a given model 032 */ 033 boolean isTooLong(String input); 034 035 /** 036 * @return number of dimentions 037 */ 038 int getDimensions(); 039 040 /** 041 * Generates embeddings for a single string 042 * @param model 043 * @param input 044 * @return 045 */ 046 default List<List<Float>> generate(String input) { 047 return generateAsync(input).block(); 048 } 049 050 /** 051 * Asynchronously generates embeddings for a single string 052 * @param model 053 * @param input 054 * @return 055 */ 056 Mono<List<List<Float>>> generateAsync(String input); 057 058 /** 059 * Batch generation 060 * @param input a list of input strings 061 * @return 062 */ 063 default Map<String, List<List<Float>>> generate(List<String> input) { 064 return generateAsync(input).block(); 065 } 066 067 /** 068 * Asynchronous batch generation 069 * @param input a list of input strings 070 * @return 071 */ 072 default Mono<Map<String, List<List<Float>>>> generateAsync(List<String> input) { 073 List<Mono<Entry<String,List<List<Float>>>>> monos = input 074 .stream() 075 .map(ie -> { 076 Mono<List<List<Float>>> emb = generateAsync(ie); 077 return emb.map(vector -> Map.entry(ie, vector)); 078 }) 079 .toList(); 080 081 return Mono.zip(monos, this::combine); 082 } 083 084 private Map<String, List<List<Float>>> combine(Object[] elements) { 085 Map<String, List<List<Float>>> ret = new LinkedHashMap<>(); 086 for (Object el: elements) { 087 @SuppressWarnings("unchecked") 088 Entry<String,List<List<Float>>> e = (Entry<String,List<List<Float>>>) el; 089 ret.put(e.getKey(), e.getValue()); 090 } 091 return ret; 092 } 093 094}