001package org.nasdanika.ai;
002
003import java.util.ArrayList;
004import java.util.List;
005
006import reactor.core.publisher.Mono;
007
008/**
009 * 
010 * @param <T> a container of tokens, e.g. int[] or char[] or List&lt;Integer&gt;
011 */
012public abstract class ChunkingEmbeddings<T> implements Embeddings {
013        
014        private Embeddings target;
015        private int chunkSize;
016        private int overlap;
017
018        /**
019         * 
020         * @param target
021         * @param chunkSize Chunk size, if non-positive, then target max input tokens is used as chunk size
022         * @param overlap
023         */
024        protected ChunkingEmbeddings(
025                        Embeddings target,
026                        int chunkSize, 
027                        int overlap) {
028                this.target = target;
029                this.chunkSize = chunkSize > 0 ? chunkSize : target.getMaxInputTokens();
030                this.overlap = overlap;
031        }
032
033        @Override
034        public String getProvider() {
035                return target.getProvider();
036        }
037
038        @Override
039        public String getName() {
040                return target.getName();
041        }
042
043        @Override
044        public String getVersion() {
045                return target.getVersion();
046        }
047
048        @Override
049        public int getMaxInputTokens() {
050                return -1;
051        }
052
053        @Override
054        public boolean isTooLong(String input) {
055                return false;
056        }
057
058        @Override
059        public int getDimensions() {
060                return target.getDimensions();
061        }
062        
063        public List<String> chunk(String input) {
064                List<String> result = new ArrayList<>();
065                T tokens = encode(input);
066                for (int i = 0, l = size(tokens); i < l; i += chunkSize) {
067                        if (i > overlap) {
068                                i -= overlap;
069                        }
070                        T slice = slice(tokens, i, chunkSize);
071                        result.add(decode(slice));
072                }
073                return result;          
074        }
075
076        @Override
077        public Mono<List<List<Float>>> generateAsync(String input) {
078                List<String> chunks = chunk(input);
079                return target.generateAsync(chunks).map(chunkMap -> {
080                        List<List<Float>> result = new ArrayList<>();
081                        for (String chunk: chunks) {
082                                result.addAll(chunkMap.get(chunk));
083                        }
084                        return result;
085                });
086        }
087        
088        @Override
089        public List<List<Float>> generate(String input) {
090                List<List<Float>> result = new ArrayList<>();
091                for (String chunk: chunk(input)) {              
092                        result.addAll(target.generate(chunk));
093                }
094                return result;
095        }
096        
097        /**
098         * Encodes a string into tokens
099         * @param input
100         * @return
101         */
102        protected abstract T encode(String input);
103        
104        /**
105         * Decodes a string from an array of tokens
106         * @param tokens
107         * @return
108         */
109        protected abstract String decode(T tokens);
110        
111        protected abstract int size(T tokens);
112        
113        protected abstract T slice(T tokens, int offset, int length);
114        
115}