001package org.nasdanika.rag.core;
002
import java.lang.reflect.Array;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Objects;

import org.nasdanika.common.ProgressMonitor;

import com.github.jelmerk.knn.Index;
import com.github.jelmerk.knn.Item;
import com.github.jelmerk.knn.hnsw.HnswIndex;
013
014public class IndexStore<K, V, D> implements Store<K, V, D> {
015        
016        protected Index<V, K, IndexItem<V, K>, D> index;
017
018        public IndexStore(Index<V, K, IndexItem<V,K>, D> index) {
019                this.index = index;             
020        }
021        
022        public IndexStore(
023                        HnswIndex<V, K, IndexItem<V,K>, D> index, 
024                        Iterable<Map.Entry<K, V>> entries,
025                        ProgressMonitor progressMonitor) {
026                this(index);
027                if (entries != null) {
028                        entries.forEach(entry -> add(entry.getKey(), entry.getValue(), progressMonitor));
029                }
030        }
031        
032        public static class IndexItem<V,K> implements Item<V,K> {
033                
034                private V value;
035                private K key;
036                int dimensions;
037
038                public IndexItem(V value, K key) {
039                        this.value = value;
040                        this.key = key;
041                        
042                        if (key.getClass().isArray()) {
043                                dimensions = Array.getLength(key);
044                        } else if (key instanceof Collection) {
045                                dimensions = ((Collection<?>) key).size();
046                        } else {
047                                dimensions = 1;
048                        }
049                }
050
051                @Override
052                public V id() {
053                        return value;
054                }
055
056                @Override
057                public K vector() {
058                        return key;
059                }
060
061                @Override
062                public int dimensions() {
063                        return dimensions;
064                }
065                
066        }
067        
068        protected IndexItem<V,K> createItem(K key, V value) {
069                return new IndexItem<V,K>(value, key);
070        }
071
072        @Override
073        public void add(K key, V value, ProgressMonitor progressMonitor) {
074                index.add(createItem(key, value));
075        }
076
077        @Override
078        public List<SearchResult<V, D>> findNearest(K key, int limit) {
079                
080                class SearchResultImpl implements SearchResult<V,D> {
081                        
082                        com.github.jelmerk.knn.SearchResult<IndexItem<V,K>,D> target;
083                        
084                        public SearchResultImpl(com.github.jelmerk.knn.SearchResult<IndexItem<V,K>,D> sr) {
085                                target = sr;
086                        }
087                        
088                        @Override
089                        public int compareTo(SearchResult<V, D> o) {
090                                return target.compareTo(((SearchResultImpl) o).target);
091                        }
092                        
093                        @Override
094                        public V getValue() {
095                                return target.item().id();
096                        }
097                        
098                        @Override
099                        public D getDistance() {
100                                return target.distance();
101                        }
102                        
103                }
104                
105                return index.findNearest(key, limit).stream().map(sr -> (SearchResult<V,D>) new SearchResultImpl(sr)).toList();
106        }
107
108}