001package org.nasdanika.ai;
002
003import java.util.Arrays;
004
005import com.knuddels.jtokkit.Encodings;
006import com.knuddels.jtokkit.api.Encoding;
007import com.knuddels.jtokkit.api.EncodingRegistry;
008import com.knuddels.jtokkit.api.EncodingType;
009import com.knuddels.jtokkit.api.IntArrayList;
010
011public class EncodingChunkingEmbeddings extends ChunkingEmbeddings<IntArrayList> {
012        
013        private Encoding encoding;
014
015        public EncodingChunkingEmbeddings(
016                        Embeddings target,
017                        int chunkSize, 
018                        int overlap,
019                        Encoding encoding) {
020                
021                super(target, chunkSize, overlap);
022                this.encoding = encoding;
023        }
024
025        public EncodingChunkingEmbeddings(
026                        Embeddings target,
027                        int chunkSize, 
028                        int overlap,
029                        EncodingType encodingType) {
030                
031                super(target, chunkSize, overlap);
032                EncodingRegistry registry = Encodings.newDefaultEncodingRegistry();
033                encoding = registry.getEncoding(encodingType);
034        }
035        
036        @Override
037        protected IntArrayList encode(String input) {
038                return encoding.encode(input);
039        }
040
041        @Override
042        protected String decode(IntArrayList tokens) {
043                return encoding.decode(tokens);
044        }
045
046        @Override
047        protected int size(IntArrayList tokens) {
048                return tokens.size();
049        }
050
051        @Override
052        protected IntArrayList slice(IntArrayList tokens, int offset, int length) {
053                int[] ia = tokens.toArray();
054                IntArrayList slice = new IntArrayList(); 
055                for (int i: Arrays.copyOfRange(ia, offset, Math.min(ia.length, offset + length))) {
056                        slice.add(i);
057                }
058                return slice;
059        }
060
061}