001package org.nasdanika.ai;
002
003import java.io.Serializable;
004import java.util.ArrayList;
005import java.util.Collection;
006import java.util.Collections;
007import java.util.List;
008import java.util.function.Function;
009
010import com.github.jelmerk.hnswlib.core.Index;
011import com.github.jelmerk.hnswlib.core.Item;
012
013import reactor.core.publisher.Mono;
014
015/**
016 * 
017 * @param <T> Search result and query type, e.g. <code>String</code>
018 * @param <D> Distance type, e.g. <code>Float</code>
019 */
020public interface SimilaritySearch<T,D extends Comparable<D>> {
021        
022        /**
023         * Finds items closest to the query
024         * @param query
025         * @param numberOfItems Number of items to return
026         * @return
027         */
028        List<SearchResult<D>> find(T query, int numberOfItems);
029        
030        /**
031         * Finds items closest to the query
032         * @param query
033         * @param numberOfItems Number of items to return
034         * @return
035         */
036        Mono<List<SearchResult<D>>> findAsync(T query, int numberOfItems);      
037        
038        default <U> SimilaritySearch<U,D> adapt(Function<U,T> mapper, Function<U, Mono<T>> asyncMapper) {
039                return new SimilaritySearch<U,D>() {
040
041                        @Override
042                        public List<SearchResult<D>> find(U query, int numberOfItems) {
043                                return SimilaritySearch.this.find(mapper.apply(query), numberOfItems);
044                        }
045
046                        @Override
047                        public Mono<List<SearchResult<D>>> findAsync(U query, int numberOfItems) {
048                                return asyncMapper
049                                        .apply(query)
050                                        .flatMap(mappedQuery -> SimilaritySearch.this.findAsync(mappedQuery, numberOfItems));
051                        }
052                        
053                };
054        }
055
056        /**
057         * Computes embeddings and uses them for similarity search in a multi-vector search.
058         * @param <D>
059         * @param multiVectorSearch
060         * @param embeddings
061         * @return
062         */
063        static <D extends Comparable<D>> SimilaritySearch<String,D> embeddingsSearch(
064                        SimilaritySearch<List<List<Float>>,D> multiVectorSearch, 
065                        Embeddings embeddings) {
066                return multiVectorSearch.adapt(
067                                embeddings::generate, 
068                                embeddings::generateAsync);
069        }
070                
071        /**
072         * Adapts a single vector search to multi-vector search
073         */
074        static <D extends Comparable<D>> SimilaritySearch<List<List<Float>>,D> adapt(SimilaritySearch<List<Float>,D> vectorSearch) {
075                return new SimilaritySearch<List<List<Float>>, D>() {
076
077                        @Override
078                        public List<SearchResult<D>> find(List<List<Float>> query, int numberOfItems) {
079                                List<SearchResult<D>> ret = new ArrayList<>();
080                                for (List<Float> qe: query) {
081                                        ret.addAll(vectorSearch.find(qe, numberOfItems));
082                                }
083                                Collections.sort(ret);
084                                return ret.size() > numberOfItems ? ret.subList(0, numberOfItems) : ret;
085                        }
086
087                        @SuppressWarnings("unchecked")
088                        @Override
089                        public Mono<List<SearchResult<D>>> findAsync(List<List<Float>> query, int numberOfItems) {
090                                Collection<Mono<List<SearchResult<D>>>> results = new ArrayList<>();
091                                for (List<Float> qe: query) {
092                                        results.add(vectorSearch.findAsync(qe, numberOfItems));
093                                }
094                                return Mono.zip(results, ra -> {
095                                        List<SearchResult<D>> ret = new ArrayList<>();
096                                        for (Object qe: (Object[]) ra) {
097                                                ret.addAll((List<SearchResult<D>>) qe);
098                                        }
099                                        Collections.sort(ret);
100                                        return ret.size() > numberOfItems ? ret.subList(0, numberOfItems) : ret;
101                                });
102                        }
103                        
104                }; 
105        }
106                
107        /**
108         * Index id - item URI and embedding vector index for URIs with multiple vectors/chunks.
109         */
110        record IndexId(String uri, int index) implements Serializable {}
111        
112        /**
113         * Vector index item
114         */
115        record EmbeddingsItem(IndexId id, float[] vector, int dimensions) implements Item<IndexId,float[]> {}
116        
117        static SimilaritySearch<List<Float>, Float> from(Index<IndexId, float[], EmbeddingsItem, Float> index) {
118                return from(index, Function.identity());
119        }
120        
121        static SimilaritySearch<List<Float>, Float> from(
122                        Index<IndexId, float[], EmbeddingsItem, Float> index,
123                        Function<float[], float[]> normalizer) {
124                
125                return new SimilaritySearch<List<Float>, Float>() {
126                                        
127                        @Override
128                        public Mono<List<SearchResult<Float>>> findAsync(List<Float> query, int numberOfItems) {
129                                return Mono.just(find(query, numberOfItems));
130                        }
131                        
132                        @Override
133                        public List<SearchResult<Float>> find(List<Float> query, int numberOfItems) {
134                                float[] fVector = new float[query.size()];
135                                for (int j = 0; j < fVector.length; ++j) {
136                                        fVector[j] = query.get(j);
137                                }
138                                if (normalizer != null) {
139                                        fVector = normalizer.apply(fVector);
140                                }
141                                List<SearchResult<Float>> ret = new ArrayList<>();
142                                for (com.github.jelmerk.hnswlib.core.SearchResult<EmbeddingsItem, Float> nearest: index.findNearest(fVector, numberOfItems)) {
143                                        String uri = nearest.item().id().uri();
144                                        int index = nearest.item().id().index();
145                                        Float distance = nearest.distance();
146                                        
147                                        ret.add(new SearchResult<Float>() {
148                                                
149                                                @Override
150                                                public String getUri() {
151                                                        return uri;
152                                                }
153                                                
154                                                @Override
155                                                public int getIndex() {
156                                                        return index;
157                                                }
158                                                
159                                                @Override
160                                                public Float getDistance() {
161                                                        return distance;
162                                                }
163                                                
164                                        });
165                                }
166                                return ret;
167                        }
168                        
169                };              
170        }
171                        
172}