001package org.nasdanika.ai;
002
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;
005
006/**
007 * Caches image embeddings in a map which can be loaded and saved between runs
008 * Uses image digest as caching key
009 */
010public abstract class MapCachingEmbeddingGenerator<S,E,K> extends CachingEmbeddingGenerator<S,E,K> {
011        
012        protected Map<K, E> cache;
013        
014        /**
015         * Computes caching key from source.
016         * @param source
017         * @return
018         */
019        protected abstract K computeKey(S source);
020
021        protected MapCachingEmbeddingGenerator(EmbeddingGenerator<S, E> target, Map<K,E> cache) {
022                super(target);
023                this.cache = cache;
024        }
025        
026        /**
027         * Creates an instance which uses the provided key computer.
028         * @param <S>
029         * @param <E>
030         * @param <K>
031         * @param target
032         * @param cache
033         * @param keyComputer
034         * @return
035         */
036        public static <S,E,K> MapCachingEmbeddingGenerator<S,E,K> create(EmbeddingGenerator<S, E> target, Map<K,E> cache, Function<S,K> keyComputer) {
037                return new MapCachingEmbeddingGenerator<S, E, K>(target, cache) {
038
039                        @Override
040                        protected K computeKey(S source) {
041                                return keyComputer.apply(source);
042                        }
043                        
044                };
045        }
046        
047        @Override
048        protected E get(K key) {
049                synchronized (cache) {
050                        return cache.get(key);
051                }
052        }
053
054        @Override
055        protected void put(K key, E value) {
056                synchronized (cache) {
057                        cache.putIfAbsent(key, value);
058                }
059        }       
060
061        @Override
062        public E generate(S input) {
063                synchronized (cache) {
064                        return cache.computeIfAbsent(computeKey(input), d -> target.generate(input));
065                }
066        }
067        
068}