001package org.nasdanika.rag.core; 002 003import java.util.Collection; 004import java.util.List; 005import java.util.Map; 006import java.util.stream.Stream; 007import java.util.stream.StreamSupport; 008 009import org.nasdanika.common.ProgressMonitor; 010 011/** 012 * 013 * Simple store implementation on top of a collection of entries 014 * @param <K> 015 * @param <V> 016 * @param <D> 017 */ 018public abstract class AbstractEntryStore<K,V,D> implements Store<K,V,D> { 019 020 public abstract Collection<Map.Entry<K,V>> getEntries(); 021 022 protected abstract D distance(K a, K b); 023 024 protected abstract int compareDistance(D a, D b); 025 026 @Override 027 public void add(K key, V value, ProgressMonitor progressMonitor) { 028 getEntries().add(Map.entry(key, value)); 029 } 030 031 protected boolean isParallelSearch(K key) { 032 return true; 033 } 034 035 protected boolean isParallelAddAll() { 036 return false; 037 } 038 039 @Override 040 public void addAll(Iterable<V> values, KeyExtractor<V, K> keyExtractor, ProgressMonitor progressMonitor) { 041 addAll(StreamSupport.stream(values.spliterator(), isParallelAddAll()), keyExtractor, progressMonitor); 042 } 043 044 @Override 045 public List<SearchResult<V, D>> findNearest(K key, int limit) { 046 Stream<Map.Entry<K,V>> entryStream = isParallelSearch(key) ? getEntries().parallelStream() : getEntries().stream(); 047 048 class SearchResultImpl implements SearchResult<V,D> { 049 050 D distance; 051 052 V value; 053 054 public SearchResultImpl(Map.Entry<K, V> entry) { 055 distance = distance(entry.getKey(), key); 056 value = entry.getValue(); 057 } 058 059 @Override 060 public int compareTo(SearchResult<V, D> o) { 061 return compareDistance(distance, o.getDistance()); 062 } 063 064 @Override 065 public V getValue() { 066 return value; 067 } 068 069 @Override 070 public D getDistance() { 071 return distance; 072 } 073 074 } 075 076 return entryStream 077 .map(e -> (SearchResult<V,D>) new SearchResultImpl(e)) 078 .limit(limit) 079 .toList(); 080 } 081 082}