001package org.nasdanika.ai; 002 003import java.util.ArrayList; 004import java.util.List; 005 006import reactor.core.publisher.Mono; 007 008/** 009 * 010 * @param <T> a container of tokens, e.g. int[] or char[] or List<Integer> 011 */ 012public abstract class ChunkingEmbeddings<T> implements Embeddings { 013 014 private Embeddings target; 015 private int chunkSize; 016 private int overlap; 017 018 /** 019 * 020 * @param target 021 * @param chunkSize Chunk size, if non-positive, then target max input tokens is used as chunk size 022 * @param overlap 023 */ 024 protected ChunkingEmbeddings( 025 Embeddings target, 026 int chunkSize, 027 int overlap) { 028 this.target = target; 029 this.chunkSize = chunkSize > 0 ? chunkSize : target.getMaxInputTokens(); 030 this.overlap = overlap; 031 } 032 033 @Override 034 public String getProvider() { 035 return target.getProvider(); 036 } 037 038 @Override 039 public String getName() { 040 return target.getName(); 041 } 042 043 @Override 044 public String getVersion() { 045 return target.getVersion(); 046 } 047 048 @Override 049 public int getMaxInputTokens() { 050 return -1; 051 } 052 053 @Override 054 public boolean isTooLong(String input) { 055 return false; 056 } 057 058 @Override 059 public int getDimensions() { 060 return target.getDimensions(); 061 } 062 063 public List<String> chunk(String input) { 064 List<String> result = new ArrayList<>(); 065 T tokens = encode(input); 066 for (int i = 0, l = size(tokens); i < l; i += chunkSize) { 067 if (i > overlap) { 068 i -= overlap; 069 } 070 T slice = slice(tokens, i, chunkSize); 071 result.add(decode(slice)); 072 } 073 return result; 074 } 075 076 @Override 077 public Mono<List<List<Float>>> generateAsync(String input) { 078 List<String> chunks = chunk(input); 079 return target.generateAsync(chunks).map(chunkMap -> { 080 List<List<Float>> result = new ArrayList<>(); 081 for (String chunk: chunks) { 082 result.addAll(chunkMap.get(chunk)); 083 } 084 return result; 085 }); 086 } 087 088 @Override 089 public List<List<Float>> generate(String input) { 090 List<List<Float>> result = new ArrayList<>(); 091 for (String chunk: chunk(input)) { 092 result.addAll(target.generate(chunk)); 093 } 094 return result; 095 } 096 097 /** 098 * Encodes a string into tokens 099 * @param input 100 * @return 101 */ 102 protected abstract T encode(String input); 103 104 /** 105 * Decodes a string from an array of tokens 106 * @param tokens 107 * @return 108 */ 109 protected abstract String decode(T tokens); 110 111 protected abstract int size(T tokens); 112 113 protected abstract T slice(T tokens, int offset, int length); 114 115}