001package org.nasdanika.rag.core; 002 003import java.util.List; 004import java.util.function.Function; 005import java.util.stream.Stream; 006import java.util.stream.StreamSupport; 007 008import org.nasdanika.common.ProgressMonitor; 009 010/** 011 * Abstraction of stores which can find values by key similarity/distance 012 * @param <K> Key type, e.g. String 013 * @param <V> Value type, e.g. double[] or List<Double> 014 * @param <D> Distance type, e.g. Double. Can be Void or Boolean for key types which don't support distance/similarity computation, only equality. 015 */ 016public interface Store<K,V,D> { 017 018 interface SearchResult<V,D> extends Comparable<SearchResult<V,D>> { 019 020 V getValue(); 021 022 D getDistance(); 023 024 } 025 026 void add(K key, V value, ProgressMonitor progressMonitor); 027 028 default void add(V value, KeyExtractor<V, K> keyExtractor, ProgressMonitor progressMonitor) { 029 add(keyExtractor.extract(value, progressMonitor), value, progressMonitor); 030 } 031 032 default void addAll(Iterable<V> values, KeyExtractor<V, K> keyExtractor, ProgressMonitor progressMonitor) { 033 addAll(StreamSupport.stream(values.spliterator(), false), keyExtractor, progressMonitor); // Implementations may customize parallel behavior 034 } 035 036 default void addAll(Stream<V> values, KeyExtractor<V, K> keyExtractor, ProgressMonitor progressMonitor) { 037 values.forEach(v -> add(v, keyExtractor, progressMonitor)); 038 } 039 040// default void addAll(Iterable<V> values, Executor executor, KeyExtractor<V, K> keyExtractor, ProgressMonitor progressMonitor) { 041// throw new UnsupportedOperationException("TODO"); 042// } 043 044 /** 045 * Returns values with distances for nearest keys. Ordered by distance. 046 * @param key 047 * @param limit maximum number of elements to return 048 * @return 049 */ 050 List<SearchResult<V,D>> findNearest(K key, int limit); 051 052 default <L,U,E> Store<L,U,E> adapt( 053 Function<L,K> keyEncoder, 054 Function<U,V> valueEncoder, 055 Function<V,U> valueDecoder, 056 Function<D,E> distanceDecoder) { 057 058 return new Store<L,U,E>() { 059 060 class SearchResultAdapter implements SearchResult<U,E> { 061 062 private SearchResult<V, D> target; 063 064 SearchResultAdapter(SearchResult<V,D> target) { 065 this.target = target; 066 } 067 068 @Override 069 public int compareTo(SearchResult<U, E> o) { 070 if (o == this) { 071 return 0; 072 } 073 074 if (SearchResultAdapter.class.isInstance(o)) { 075 return target.compareTo(((SearchResultAdapter) o).target); 076 } 077 throw new IllegalArgumentException(); 078 } 079 080 @Override 081 public U getValue() { 082 return valueDecoder.apply(target.getValue()); 083 } 084 085 @Override 086 public E getDistance() { 087 return distanceDecoder.apply(target.getDistance()); 088 } 089 090 } 091 092 @Override 093 public void add(L key, U value, ProgressMonitor progressMonitor) { 094 Store.this.add(keyEncoder.apply(key), valueEncoder.apply(value), progressMonitor); 095 } 096 097 @Override 098 public List<SearchResult<U, E>> findNearest(L key, int limit) { 099 return Store.this 100 .findNearest(keyEncoder.apply(key), limit) 101 .stream() 102 .map(sr -> (SearchResult<U,E>) new SearchResultAdapter(sr)) 103 .toList(); 104 } 105 106 }; 107 108 } 109 110}