001package org.nasdanika.rag.core;
002
003import java.util.Collection;
004import java.util.List;
005import java.util.Map;
006import java.util.stream.Stream;
007import java.util.stream.StreamSupport;
008
009import org.nasdanika.common.ProgressMonitor;
010
011/**
012 * 
013 * Simple store implementation on top of a collection of entries
014 * @param <K>
015 * @param <V>
016 * @param <D>
017 */
018public abstract class AbstractEntryStore<K,V,D> implements Store<K,V,D> {
019        
020        public abstract Collection<Map.Entry<K,V>> getEntries();
021        
022        protected abstract D distance(K a, K b);
023        
024        protected abstract int compareDistance(D a, D b);
025
026        @Override
027        public void add(K key, V value, ProgressMonitor progressMonitor) {
028                getEntries().add(Map.entry(key, value));                
029        }
030        
031        protected boolean isParallelSearch(K key) {
032                return true;
033        }
034        
035        protected boolean isParallelAddAll() {
036                return false;
037        }
038        
039        @Override
040        public void addAll(Iterable<V> values, KeyExtractor<V, K> keyExtractor, ProgressMonitor progressMonitor) {
041                addAll(StreamSupport.stream(values.spliterator(), isParallelAddAll()), keyExtractor, progressMonitor);
042        }       
043
044        @Override
045        public List<SearchResult<V, D>> findNearest(K key, int limit) {
046                Stream<Map.Entry<K,V>> entryStream = isParallelSearch(key) ? getEntries().parallelStream() : getEntries().stream();
047                
048                class SearchResultImpl implements SearchResult<V,D> {
049                        
050                        D distance;
051                        
052                        V value;
053                        
054                        public SearchResultImpl(Map.Entry<K, V> entry) {
055                                distance = distance(entry.getKey(), key);
056                                value = entry.getValue(); 
057                        }
058                        
059                        @Override
060                        public int compareTo(SearchResult<V, D> o) {
061                                return compareDistance(distance, o.getDistance());
062                        }
063                        
064                        @Override
065                        public V getValue() {
066                                return value;
067                        }
068                        
069                        @Override
070                        public D getDistance() {
071                                return distance;
072                        }
073                        
074                }
075                
076                return entryStream
077                                .map(e -> (SearchResult<V,D>) new SearchResultImpl(e))
078                                .limit(limit)
079                                .toList();
080        }
081
082}