001package org.nasdanika.rag.core; 002 003import java.lang.reflect.Array; 004import java.util.Collection; 005import java.util.List; 006import java.util.Map; 007 008import org.nasdanika.common.ProgressMonitor; 009 010import com.github.jelmerk.knn.Index; 011import com.github.jelmerk.knn.Item; 012import com.github.jelmerk.knn.hnsw.HnswIndex; 013 014public class IndexStore<K, V, D> implements Store<K, V, D> { 015 016 protected Index<V, K, IndexItem<V, K>, D> index; 017 018 public IndexStore(Index<V, K, IndexItem<V,K>, D> index) { 019 this.index = index; 020 } 021 022 public IndexStore( 023 HnswIndex<V, K, IndexItem<V,K>, D> index, 024 Iterable<Map.Entry<K, V>> entries, 025 ProgressMonitor progressMonitor) { 026 this(index); 027 if (entries != null) { 028 entries.forEach(entry -> add(entry.getKey(), entry.getValue(), progressMonitor)); 029 } 030 } 031 032 public static class IndexItem<V,K> implements Item<V,K> { 033 034 private V value; 035 private K key; 036 int dimensions; 037 038 public IndexItem(V value, K key) { 039 this.value = value; 040 this.key = key; 041 042 if (key.getClass().isArray()) { 043 dimensions = Array.getLength(key); 044 } else if (key instanceof Collection) { 045 dimensions = ((Collection<?>) key).size(); 046 } else { 047 dimensions = 1; 048 } 049 } 050 051 @Override 052 public V id() { 053 return value; 054 } 055 056 @Override 057 public K vector() { 058 return key; 059 } 060 061 @Override 062 public int dimensions() { 063 return dimensions; 064 } 065 066 } 067 068 protected IndexItem<V,K> createItem(K key, V value) { 069 return new IndexItem<V,K>(value, key); 070 } 071 072 @Override 073 public void add(K key, V value, ProgressMonitor progressMonitor) { 074 index.add(createItem(key, value)); 075 } 076 077 @Override 078 public List<SearchResult<V, D>> findNearest(K key, int limit) { 079 080 class SearchResultImpl implements SearchResult<V,D> { 081 082 com.github.jelmerk.knn.SearchResult<IndexItem<V,K>,D> target; 083 084 public SearchResultImpl(com.github.jelmerk.knn.SearchResult<IndexItem<V,K>,D> sr) { 085 target = sr; 086 } 087 088 @Override 089 public int compareTo(SearchResult<V, D> o) { 090 return target.compareTo(((SearchResultImpl) o).target); 091 } 092 093 @Override 094 public V getValue() { 095 return target.item().id(); 096 } 097 098 @Override 099 public D getDistance() { 100 return target.distance(); 101 } 102 103 } 104 105 return index.findNearest(key, limit).stream().map(sr -> (SearchResult<V,D>) new SearchResultImpl(sr)).toList(); 106 } 107 108}