001package org.nasdanika.ai.cli;
002
003import java.io.File;
004import java.io.IOException;
005import java.io.InputStream;
006import java.util.ArrayList;
007import java.util.List;
008import java.util.Map;
009import java.util.function.Function;
010import java.util.stream.Collectors;
011
012import org.nasdanika.ai.TextFloatVectorEmbeddingModel;
013import org.nasdanika.ai.TextFloatVectorEncodingChunkingEmbeddingModel;
014import org.nasdanika.ai.SimilaritySearch.EmbeddingsItem;
015import org.nasdanika.ai.SimilaritySearch.IndexId;
016import org.nasdanika.capability.CapabilityLoader;
017import org.nasdanika.cli.ProgressMonitorMixIn;
018import org.nasdanika.cli.TelemetryCommand;
019import org.nasdanika.common.ProgressMonitor;
020
021import com.github.jelmerk.hnswlib.core.hnsw.HnswIndex;
022
023import io.opentelemetry.api.OpenTelemetry;
024import io.opentelemetry.api.common.Attributes;
025import io.opentelemetry.api.trace.Span;
026import io.opentelemetry.context.Context;
027import picocli.CommandLine;
028import picocli.CommandLine.ArgGroup;
029import picocli.CommandLine.Parameters;
030import reactor.core.publisher.Flux;
031import reactor.core.publisher.Mono;
032
033/**
034 * Base command for creating vector index files by generating embeddings.
035 */
036public abstract class HnswIndexCommandBase extends TelemetryCommand {
037
038        public HnswIndexCommandBase(OpenTelemetry openTelemetry, CapabilityLoader capabilityLoader) {
039                super(openTelemetry, capabilityLoader);
040        }
041                
042        @Parameters(
043                index =  "0",   
044                arity = "1",
045                description = "Index output file")
046        private File output;
047        
048        @ArgGroup(
049                        heading = "Progress monitor%n",
050                        exclusive = false)
051        private ProgressMonitorMixIn progressMonitorMixIn;
052        
053        @ArgGroup(
054                        heading = "Chunking%n",
055                        exclusive = false)
056        private TextFloatVectorEncodingChunkingEmbeddingsArgGroup encodingChunkingEmbeddingsArgGroup;
057        
058        @ArgGroup(
059                        heading = "TextFloatVectorEmbeddingModel%n",
060                        exclusive = false)
061        private TextFloatVectorEmbeddingsArgGroup embeddingsArgGroup;
062                
063        @ArgGroup(
064                        heading = "Vector index%n",
065                        exclusive = false)
066        private HnswIndexBuilderFloatArgGroup hnswIndexArgGroup;        
067                
068        @Override
069        public Integer execute(Span commandSpan) throws Exception {
070                if (progressMonitorMixIn == null) {
071                        progressMonitorMixIn = new ProgressMonitorMixIn();
072                }
073                try (ProgressMonitor progressMonitor = progressMonitorMixIn.createProgressMonitor(1)) {
074                        if (encodingChunkingEmbeddingsArgGroup == null) {
075                                encodingChunkingEmbeddingsArgGroup = new TextFloatVectorEncodingChunkingEmbeddingsArgGroup();
076                        }
077                        encodingChunkingEmbeddingsArgGroup.setSpanAttributes(commandSpan);
078                        if (embeddingsArgGroup == null) {
079                                embeddingsArgGroup = new TextFloatVectorEmbeddingsArgGroup();
080                        }
081                        embeddingsArgGroup.setSpanAttributes(commandSpan);
082                        if (hnswIndexArgGroup == null) {
083                                hnswIndexArgGroup = new HnswIndexBuilderFloatArgGroup();
084                        }
085                        hnswIndexArgGroup.setSpanAttributes(commandSpan);
086                        
087                        TextFloatVectorEmbeddingModel embeddings = embeddingsArgGroup.loadOne(getCapabilityLoader(), progressMonitor);
088                        if (embeddings == null) {
089                                throw new CommandLine.ExecutionException(spec.commandLine(), "Embedding model is not available");
090                        }
091                        
092                        TextFloatVectorEncodingChunkingEmbeddingModel chunkingEmbeddings = encodingChunkingEmbeddingsArgGroup.createChunkingEmbeddings(embeddings);
093        
094                        Function<Map.Entry<String,String>, Flux<EmbeddingsItem>> mapper = entry -> {
095                                Mono<List<List<Float>>> vectorsMono = chunkingEmbeddings.generateAsync(entry.getValue());
096                                return vectorsMono.map(vectors -> {
097                                        List<EmbeddingsItem> result = new ArrayList<>();
098                                        int idx = 0;
099                                        for (List<Float> vector: vectors) {
100                                                float[] fVector = new float[vector.size()];
101                                                for (int j = 0; j < fVector.length; ++j) {
102                                                        fVector[j] = vector.get(j);
103                                                }
104                                                result.add(new EmbeddingsItem(
105                                                                new IndexId(entry.getKey(), idx++), 
106                                                                hnswIndexArgGroup.normalize(fVector), 
107                                                                vector.size()));                                                
108                                        }
109                                        onGenerateEmbeddings(entry.getKey(), entry.getValue(), vectors, result);
110                                        return result;
111                                }).flatMapIterable(Function.identity());
112                        };
113                        
114                        List<EmbeddingsItem> items = getItems(commandSpan, progressMonitor)
115                                .flatMap(mapper)
116                                .collect(Collectors.toList())
117                                .contextWrite(reactor.util.context.Context.of(Context.class, Context.current().with(commandSpan)))
118                                .block();
119                        
120                        commandSpan.addEvent(
121                                        "items-loaded", 
122                                        Attributes
123                                                .builder()
124                                                .put("size", items.size())
125                                                .build());
126        
127                        HnswIndex<IndexId, float[], EmbeddingsItem, Float> index = hnswIndexArgGroup.buildAndAddAll(embeddings.getDimensions(), items, commandSpan);
128                        index.save(output);             
129                        
130                        return 0;
131                }
132        }
133        
134        /**
135         * Listener method for the generation process.
136         * You may override this method for progress reporting. 
137         * @param key
138         * @param value
139         * @param embeddingsItems
140         */
141        protected void onGenerateEmbeddings(String key, String value, List<List<Float>> vectors, List<EmbeddingsItem> embeddingsItems) {
142                
143        }
144                
145        /**
146         * A {@link Flux} of items which are mapped to {@link EmbeddingsItem} and then stored to the index. 
147         * @param commandSpan
148         * @return
149         */
150        protected abstract Flux<Map.Entry<String,String>> getItems(Span commandSpan, ProgressMonitor progressMonitor);
151        
152        public static HnswIndex<IndexId, float[], EmbeddingsItem, Float> loadIndex(File file) throws IOException {
153                return HnswIndex.load(file);
154        }
155        
156        public static HnswIndex<IndexId, float[], EmbeddingsItem, Float> loadIndex(InputStream in) throws IOException {
157                return HnswIndex.load(in);
158        }
159
160}