001package org.nasdanika.ai.cli; 002 003import java.io.File; 004import java.io.IOException; 005import java.io.InputStream; 006import java.util.ArrayList; 007import java.util.List; 008import java.util.Map; 009import java.util.function.Function; 010import java.util.stream.Collectors; 011 012import org.nasdanika.ai.Embeddings; 013import org.nasdanika.ai.EncodingChunkingEmbeddings; 014import org.nasdanika.ai.SimilaritySearch.EmbeddingsItem; 015import org.nasdanika.ai.SimilaritySearch.IndexId; 016import org.nasdanika.capability.CapabilityLoader; 017import org.nasdanika.cli.ProgressMonitorMixIn; 018import org.nasdanika.cli.TelemetryCommand; 019import org.nasdanika.common.ProgressMonitor; 020 021import com.github.jelmerk.hnswlib.core.hnsw.HnswIndex; 022 023import io.opentelemetry.api.OpenTelemetry; 024import io.opentelemetry.api.common.Attributes; 025import io.opentelemetry.api.trace.Span; 026import io.opentelemetry.context.Context; 027import picocli.CommandLine; 028import picocli.CommandLine.ArgGroup; 029import picocli.CommandLine.Parameters; 030import reactor.core.publisher.Flux; 031import reactor.core.publisher.Mono; 032 033/** 034 * Base command for creating vector index files by generating embeddings. 035 */ 036public abstract class HnswIndexCommandBase extends TelemetryCommand { 037 038 public HnswIndexCommandBase(OpenTelemetry openTelemetry, CapabilityLoader capabilityLoader) { 039 super(openTelemetry, capabilityLoader); 040 } 041 042 @Parameters( 043 index = "0", 044 arity = "1", 045 description = "Index output file") 046 private File output; 047 048 @ArgGroup( 049 heading = "Progress monitor%n", 050 exclusive = false) 051 private ProgressMonitorMixIn progressMonitorMixIn; 052 053 @ArgGroup( 054 heading = "Chunking%n", 055 exclusive = false) 056 private EncodingChunkingEmbeddingsArgGroup encodingChunkingEmbeddingsArgGroup; 057 058 @ArgGroup( 059 heading = "Embeddings%n", 060 exclusive = false) 061 private EmbeddingsArgGroup embeddingsArgGroup; 062 063 @ArgGroup( 064 heading = "Vector index%n", 065 exclusive = false) 066 private HnswIndexBuilderFloatArgGroup hnswIndexArgGroup; 067 068 @Override 069 public Integer execute(Span commandSpan) throws Exception { 070 if (progressMonitorMixIn == null) { 071 progressMonitorMixIn = new ProgressMonitorMixIn(); 072 } 073 try (ProgressMonitor progressMonitor = progressMonitorMixIn.createProgressMonitor(1)) { 074 if (encodingChunkingEmbeddingsArgGroup == null) { 075 encodingChunkingEmbeddingsArgGroup = new EncodingChunkingEmbeddingsArgGroup(); 076 } 077 encodingChunkingEmbeddingsArgGroup.setSpanAttributes(commandSpan); 078 if (embeddingsArgGroup == null) { 079 embeddingsArgGroup = new EmbeddingsArgGroup(); 080 } 081 embeddingsArgGroup.setSpanAttributes(commandSpan); 082 if (hnswIndexArgGroup == null) { 083 hnswIndexArgGroup = new HnswIndexBuilderFloatArgGroup(); 084 } 085 hnswIndexArgGroup.setSpanAttributes(commandSpan); 086 087 Embeddings embeddings = embeddingsArgGroup.loadOne(getCapabilityLoader(), progressMonitor); 088 if (embeddings == null) { 089 throw new CommandLine.ExecutionException(spec.commandLine(), "Embedding model is not available"); 090 } 091 092 EncodingChunkingEmbeddings chunkingEmbeddings = encodingChunkingEmbeddingsArgGroup.createChunkingEmbeddings(embeddings); 093 094 Function<Map.Entry<String,String>, Flux<EmbeddingsItem>> mapper = entry -> { 095 Mono<List<List<Float>>> vectorsMono = chunkingEmbeddings.generateAsync(entry.getValue()); 096 return vectorsMono.map(vectors -> { 097 List<EmbeddingsItem> result = new ArrayList<>(); 098 int idx = 0; 099 for (List<Float> vector: vectors) { 100 float[] fVector = new float[vector.size()]; 101 for (int j = 0; j < fVector.length; ++j) { 102 fVector[j] = vector.get(j); 103 } 104 result.add(new EmbeddingsItem( 105 new IndexId(entry.getKey(), idx++), 106 hnswIndexArgGroup.normalize(fVector), 107 vector.size())); 108 } 109 onGenerateEmbeddings(entry.getKey(), entry.getValue(), vectors, result); 110 return result; 111 }).flatMapIterable(Function.identity()); 112 }; 113 114 List<EmbeddingsItem> items = getItems(commandSpan, progressMonitor) 115 .flatMap(mapper) 116 .collect(Collectors.toList()) 117 .contextWrite(reactor.util.context.Context.of(Context.class, Context.current().with(commandSpan))) 118 .block(); 119 120 commandSpan.addEvent( 121 "items-loaded", 122 Attributes 123 .builder() 124 .put("size", items.size()) 125 .build()); 126 127 HnswIndex<IndexId, float[], EmbeddingsItem, Float> index = hnswIndexArgGroup.buildAndAddAll(embeddings.getDimensions(), items, commandSpan); 128 index.save(output); 129 130 return 0; 131 } 132 } 133 134 /** 135 * Listener method for the generation process. 136 * You may override this method for progress reporting. 137 * @param key 138 * @param value 139 * @param embeddingsItems 140 */ 141 protected void onGenerateEmbeddings(String key, String value, List<List<Float>> vectors, List<EmbeddingsItem> embeddingsItems) { 142 143 } 144 145 /** 146 * A {@link Flux} of items which are mapped to {@link EmbeddingsItem} and then stored to the index. 147 * @param commandSpan 148 * @return 149 */ 150 protected abstract Flux<Map.Entry<String,String>> getItems(Span commandSpan, ProgressMonitor progressMonitor); 151 152 public static HnswIndex<IndexId, float[], EmbeddingsItem, Float> loadIndex(File file) throws IOException { 153 return HnswIndex.load(file); 154 } 155 156 public static HnswIndex<IndexId, float[], EmbeddingsItem, Float> loadIndex(InputStream in) throws IOException { 157 return HnswIndex.load(in); 158 } 159 160}