001package org.nasdanika.ai; 002 003import java.util.Map; 004import java.util.function.Function; 005 006/** 007 * Caches image embeddings in a map which can be loaded and saved between runs 008 * Uses image digest as caching key 009 */ 010public abstract class MapCachingEmbeddingGenerator<S,E,K> extends CachingEmbeddingGenerator<S,E,K> { 011 012 protected Map<K, E> cache; 013 014 /** 015 * Computes caching key from source. 016 * @param source 017 * @return 018 */ 019 protected abstract K computeKey(S source); 020 021 protected MapCachingEmbeddingGenerator(EmbeddingGenerator<S, E> target, Map<K,E> cache) { 022 super(target); 023 this.cache = cache; 024 } 025 026 /** 027 * Creates an instance which uses the provided key computer. 028 * @param <S> 029 * @param <E> 030 * @param <K> 031 * @param target 032 * @param cache 033 * @param keyComputer 034 * @return 035 */ 036 public static <S,E,K> MapCachingEmbeddingGenerator<S,E,K> create(EmbeddingGenerator<S, E> target, Map<K,E> cache, Function<S,K> keyComputer) { 037 return new MapCachingEmbeddingGenerator<S, E, K>(target, cache) { 038 039 @Override 040 protected K computeKey(S source) { 041 return keyComputer.apply(source); 042 } 043 044 }; 045 } 046 047 @Override 048 protected E get(K key) { 049 synchronized (cache) { 050 return cache.get(key); 051 } 052 } 053 054 @Override 055 protected void put(K key, E value) { 056 synchronized (cache) { 057 cache.putIfAbsent(key, value); 058 } 059 } 060 061 @Override 062 public E generate(S input) { 063 synchronized (cache) { 064 return cache.computeIfAbsent(computeKey(input), d -> target.generate(input)); 065 } 066 } 067 068}