001package org.nasdanika.ai;
002
003import java.nio.charset.Charset;
004import java.nio.charset.StandardCharsets;
005import java.security.MessageDigest;
006import java.security.NoSuchAlgorithmException;
007import java.util.Map;
008
009/**
010 * Caches text embeddings in a map which can be loaded and saved between runs
011 * Uses text digest as caching key
012 */
013public class CachingTextEmbeddingGenerator<E> extends MapCachingEmbeddingGenerator<String, E, String> implements TextEmbeddingGenerator<E> {
014        
015        private String algorithm;
016
017        public CachingTextEmbeddingGenerator(TextEmbeddingGenerator<E> target, Map<String, E> cache, String algorithm) {
018                super(target, cache);
019                this.algorithm = algorithm;
020        }
021
022        /**
023         * Uses SHA-512 algorithm
024         * @param cache
025         */
026        public CachingTextEmbeddingGenerator(TextEmbeddingGenerator<E> target, Map<String, E> cache) {
027                this(target, cache, "SHA-512");
028        }
029        
030        protected String computeKey(String input) {
031                try {
032                        MessageDigest digest = MessageDigest.getInstance(algorithm);
033                        byte[] dBytes = digest.digest(input.getBytes(getCharset()));
034                        
035                    StringBuilder sb = new StringBuilder();
036                    for (byte b : dBytes) {
037                        sb.append(String.format("%02x", b));
038                    }
039                    return sb.toString();
040                } catch (NoSuchAlgorithmException e) {
041                        throw new IllegalArgumentException("Error computing text digest: " + e, e);
042                }
043        }
044
045        protected Charset getCharset() {
046                return StandardCharsets.UTF_8;
047        }
048        
049}