001package org.nasdanika.ai; 002 003import java.io.Serializable; 004import java.util.ArrayList; 005import java.util.Collection; 006import java.util.Collections; 007import java.util.List; 008import java.util.function.Function; 009 010import com.github.jelmerk.hnswlib.core.Index; 011import com.github.jelmerk.hnswlib.core.Item; 012 013import reactor.core.publisher.Mono; 014 015/** 016 * 017 * @param <T> Search result and query type, e.g. <code>String</code> 018 * @param <D> Distance type, e.g. <code>Float</code> 019 */ 020public interface SimilaritySearch<T,D extends Comparable<D>> { 021 022 /** 023 * Finds items closest to the query 024 * @param query 025 * @param numberOfItems Number of items to return 026 * @return 027 */ 028 List<SearchResult<D>> find(T query, int numberOfItems); 029 030 /** 031 * Finds items closest to the query 032 * @param query 033 * @param numberOfItems Number of items to return 034 * @return 035 */ 036 Mono<List<SearchResult<D>>> findAsync(T query, int numberOfItems); 037 038 default <U> SimilaritySearch<U,D> adapt(Function<U,T> mapper, Function<U, Mono<T>> asyncMapper) { 039 return new SimilaritySearch<U,D>() { 040 041 @Override 042 public List<SearchResult<D>> find(U query, int numberOfItems) { 043 return SimilaritySearch.this.find(mapper.apply(query), numberOfItems); 044 } 045 046 @Override 047 public Mono<List<SearchResult<D>>> findAsync(U query, int numberOfItems) { 048 return asyncMapper 049 .apply(query) 050 .flatMap(mappedQuery -> SimilaritySearch.this.findAsync(mappedQuery, numberOfItems)); 051 } 052 053 }; 054 } 055 056 /** 057 * Computes embeddings and uses them for similarity search in a multi-vector search. 058 * @param <D> 059 * @param multiVectorSearch 060 * @param embeddings 061 * @return 062 */ 063 static <D extends Comparable<D>> SimilaritySearch<String,D> embeddingsSearch( 064 SimilaritySearch<List<List<Float>>,D> multiVectorSearch, 065 Embeddings embeddings) { 066 return multiVectorSearch.adapt( 067 embeddings::generate, 068 embeddings::generateAsync); 069 } 070 071 /** 072 * Adapts a single vector search to multi-vector search 073 */ 074 static <D extends Comparable<D>> SimilaritySearch<List<List<Float>>,D> adapt(SimilaritySearch<List<Float>,D> vectorSearch) { 075 return new SimilaritySearch<List<List<Float>>, D>() { 076 077 @Override 078 public List<SearchResult<D>> find(List<List<Float>> query, int numberOfItems) { 079 List<SearchResult<D>> ret = new ArrayList<>(); 080 for (List<Float> qe: query) { 081 ret.addAll(vectorSearch.find(qe, numberOfItems)); 082 } 083 Collections.sort(ret); 084 return ret.size() > numberOfItems ? ret.subList(0, numberOfItems) : ret; 085 } 086 087 @SuppressWarnings("unchecked") 088 @Override 089 public Mono<List<SearchResult<D>>> findAsync(List<List<Float>> query, int numberOfItems) { 090 Collection<Mono<List<SearchResult<D>>>> results = new ArrayList<>(); 091 for (List<Float> qe: query) { 092 results.add(vectorSearch.findAsync(qe, numberOfItems)); 093 } 094 return Mono.zip(results, ra -> { 095 List<SearchResult<D>> ret = new ArrayList<>(); 096 for (Object qe: (Object[]) ra) { 097 ret.addAll((List<SearchResult<D>>) qe); 098 } 099 Collections.sort(ret); 100 return ret.size() > numberOfItems ? ret.subList(0, numberOfItems) : ret; 101 }); 102 } 103 104 }; 105 } 106 107 /** 108 * Index id - item URI and embedding vector index for URIs with multiple vectors/chunks. 109 */ 110 record IndexId(String uri, int index) implements Serializable {} 111 112 /** 113 * Vector index item 114 */ 115 record EmbeddingsItem(IndexId id, float[] vector, int dimensions) implements Item<IndexId,float[]> {} 116 117 static SimilaritySearch<List<Float>, Float> from(Index<IndexId, float[], EmbeddingsItem, Float> index) { 118 return from(index, Function.identity()); 119 } 120 121 static SimilaritySearch<List<Float>, Float> from( 122 Index<IndexId, float[], EmbeddingsItem, Float> index, 123 Function<float[], float[]> normalizer) { 124 125 return new SimilaritySearch<List<Float>, Float>() { 126 127 @Override 128 public Mono<List<SearchResult<Float>>> findAsync(List<Float> query, int numberOfItems) { 129 return Mono.just(find(query, numberOfItems)); 130 } 131 132 @Override 133 public List<SearchResult<Float>> find(List<Float> query, int numberOfItems) { 134 float[] fVector = new float[query.size()]; 135 for (int j = 0; j < fVector.length; ++j) { 136 fVector[j] = query.get(j); 137 } 138 if (normalizer != null) { 139 fVector = normalizer.apply(fVector); 140 } 141 List<SearchResult<Float>> ret = new ArrayList<>(); 142 for (com.github.jelmerk.hnswlib.core.SearchResult<EmbeddingsItem, Float> nearest: index.findNearest(fVector, numberOfItems)) { 143 String uri = nearest.item().id().uri(); 144 int index = nearest.item().id().index(); 145 Float distance = nearest.distance(); 146 147 ret.add(new SearchResult<Float>() { 148 149 @Override 150 public String getUri() { 151 return uri; 152 } 153 154 @Override 155 public int getIndex() { 156 return index; 157 } 158 159 @Override 160 public Float getDistance() { 161 return distance; 162 } 163 164 }); 165 } 166 return ret; 167 } 168 169 }; 170 } 171 172}