001package org.nasdanika.ai; 002 003import java.util.Arrays; 004 005import com.knuddels.jtokkit.Encodings; 006import com.knuddels.jtokkit.api.Encoding; 007import com.knuddels.jtokkit.api.EncodingRegistry; 008import com.knuddels.jtokkit.api.EncodingType; 009import com.knuddels.jtokkit.api.IntArrayList; 010 011public class EncodingChunkingEmbeddings extends ChunkingEmbeddings<IntArrayList> { 012 013 private Encoding encoding; 014 015 public EncodingChunkingEmbeddings( 016 Embeddings target, 017 int chunkSize, 018 int overlap, 019 Encoding encoding) { 020 021 super(target, chunkSize, overlap); 022 this.encoding = encoding; 023 } 024 025 public EncodingChunkingEmbeddings( 026 Embeddings target, 027 int chunkSize, 028 int overlap, 029 EncodingType encodingType) { 030 031 super(target, chunkSize, overlap); 032 EncodingRegistry registry = Encodings.newDefaultEncodingRegistry(); 033 encoding = registry.getEncoding(encodingType); 034 } 035 036 @Override 037 protected IntArrayList encode(String input) { 038 return encoding.encode(input); 039 } 040 041 @Override 042 protected String decode(IntArrayList tokens) { 043 return encoding.decode(tokens); 044 } 045 046 @Override 047 protected int size(IntArrayList tokens) { 048 return tokens.size(); 049 } 050 051 @Override 052 protected IntArrayList slice(IntArrayList tokens, int offset, int length) { 053 int[] ia = tokens.toArray(); 054 IntArrayList slice = new IntArrayList(); 055 for (int i: Arrays.copyOfRange(ia, offset, Math.min(ia.length, offset + length))) { 056 slice.add(i); 057 } 058 return slice; 059 } 060 061}