/*
 * Decompiled with CFR 0.152.
 */
package org.nasdanika.ai.cli;

import com.github.jelmerk.hnswlib.core.hnsw.HnswIndex;
import io.opentelemetry.api.OpenTelemetry;
import io.opentelemetry.api.common.Attributes;
import io.opentelemetry.api.trace.Span;
import io.opentelemetry.context.Context;
import io.opentelemetry.context.ImplicitContextKeyed;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.nasdanika.ai.SimilaritySearch;
import org.nasdanika.ai.TextFloatVectorEmbeddingModel;
import org.nasdanika.ai.TextFloatVectorEncodingChunkingEmbeddingModel;
import org.nasdanika.ai.cli.HnswIndexBuilderFloatArgGroup;
import org.nasdanika.ai.cli.TextFloatVectorEmbeddingsArgGroup;
import org.nasdanika.ai.cli.TextFloatVectorEncodingChunkingEmbeddingsArgGroup;
import org.nasdanika.capability.CapabilityLoader;
import org.nasdanika.cli.ProgressMonitorMixIn;
import org.nasdanika.cli.TelemetryCommand;
import org.nasdanika.common.ProgressMonitor;
import picocli.CommandLine;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.util.context.ContextView;

public abstract class HnswIndexCommandBase
extends TelemetryCommand {
    @CommandLine.Parameters(index="0", arity="1", description={"Index output file"})
    private File output;
    @CommandLine.ArgGroup(heading="Progress monitor%n", exclusive=false)
    private ProgressMonitorMixIn progressMonitorMixIn;
    @CommandLine.ArgGroup(heading="Chunking%n", exclusive=false)
    private TextFloatVectorEncodingChunkingEmbeddingsArgGroup encodingChunkingEmbeddingsArgGroup;
    @CommandLine.ArgGroup(heading="TextFloatVectorEmbeddingModel%n", exclusive=false)
    private TextFloatVectorEmbeddingsArgGroup embeddingsArgGroup;
    @CommandLine.ArgGroup(heading="Vector index%n", exclusive=false)
    private HnswIndexBuilderFloatArgGroup hnswIndexArgGroup;

    public HnswIndexCommandBase(OpenTelemetry openTelemetry, CapabilityLoader capabilityLoader) {
        super(openTelemetry, capabilityLoader);
    }

    public Integer execute(Span commandSpan) throws Exception {
        if (this.progressMonitorMixIn == null) {
            this.progressMonitorMixIn = new ProgressMonitorMixIn();
        }
        try (ProgressMonitor progressMonitor = this.progressMonitorMixIn.createProgressMonitor(1.0);){
            if (this.encodingChunkingEmbeddingsArgGroup == null) {
                this.encodingChunkingEmbeddingsArgGroup = new TextFloatVectorEncodingChunkingEmbeddingsArgGroup();
            }
            this.encodingChunkingEmbeddingsArgGroup.setSpanAttributes(commandSpan);
            if (this.embeddingsArgGroup == null) {
                this.embeddingsArgGroup = new TextFloatVectorEmbeddingsArgGroup();
            }
            this.embeddingsArgGroup.setSpanAttributes(commandSpan);
            if (this.hnswIndexArgGroup == null) {
                this.hnswIndexArgGroup = new HnswIndexBuilderFloatArgGroup();
            }
            this.hnswIndexArgGroup.setSpanAttributes(commandSpan);
            TextFloatVectorEmbeddingModel embeddings = (TextFloatVectorEmbeddingModel)this.embeddingsArgGroup.loadOne(this.getCapabilityLoader(), progressMonitor);
            if (embeddings == null) {
                throw new CommandLine.ExecutionException(this.spec.commandLine(), "Embedding model is not available");
            }
            TextFloatVectorEncodingChunkingEmbeddingModel chunkingEmbeddings = this.encodingChunkingEmbeddingsArgGroup.createChunkingEmbeddings(embeddings);
            Function<Map.Entry, Flux> mapper = entry -> {
                Mono vectorsMono = chunkingEmbeddings.generateAsync((String)entry.getValue());
                return vectorsMono.map(vectors -> {
                    ArrayList<SimilaritySearch.EmbeddingsItem> result = new ArrayList<SimilaritySearch.EmbeddingsItem>();
                    int idx = 0;
                    for (List vector : vectors) {
                        float[] fVector = new float[vector.size()];
                        for (int j = 0; j < fVector.length; ++j) {
                            fVector[j] = ((Float)vector.get(j)).floatValue();
                        }
                        result.add(new SimilaritySearch.EmbeddingsItem(new SimilaritySearch.IndexId((String)entry.getKey(), idx++), this.hnswIndexArgGroup.normalize(fVector), vector.size()));
                    }
                    this.onGenerateEmbeddings((String)entry.getKey(), (String)entry.getValue(), (List<List<Float>>)vectors, (List<SimilaritySearch.EmbeddingsItem>)result);
                    return result;
                }).flatMapIterable(Function.identity());
            };
            List items = (List)this.getItems(commandSpan, progressMonitor).flatMap(mapper).collect(Collectors.toList()).contextWrite((ContextView)reactor.util.context.Context.of(Context.class, (Object)Context.current().with((ImplicitContextKeyed)commandSpan))).block();
            commandSpan.addEvent("items-loaded", Attributes.builder().put("size", (long)items.size()).build());
            HnswIndex index = this.hnswIndexArgGroup.buildAndAddAll(embeddings.getDimensions(), items, commandSpan);
            index.save(this.output);
            Integer n = 0;
            return n;
        }
    }

    protected void onGenerateEmbeddings(String key, String value, List<List<Float>> vectors, List<SimilaritySearch.EmbeddingsItem> embeddingsItems) {
    }

    protected abstract Flux<Map.Entry<String, String>> getItems(Span var1, ProgressMonitor var2);

    public static HnswIndex<SimilaritySearch.IndexId, float[], SimilaritySearch.EmbeddingsItem, Float> loadIndex(File file) throws IOException {
        return HnswIndex.load((File)file);
    }

    public static HnswIndex<SimilaritySearch.IndexId, float[], SimilaritySearch.EmbeddingsItem, Float> loadIndex(InputStream in) throws IOException {
        return HnswIndex.load((InputStream)in);
    }
}

