/*
 * Decompiled with CFR 0.152.
 */
package org.nasdanika.rag.openai;

import com.azure.ai.openai.OpenAIClient;
import com.azure.ai.openai.models.EmbeddingItem;
import com.azure.ai.openai.models.Embeddings;
import com.azure.ai.openai.models.EmbeddingsOptions;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import org.nasdanika.common.ProgressMonitor;
import org.nasdanika.rag.core.KeyExtractor;
import org.nasdanika.rag.core.StringDoubleVectorKeyExtractor;
import org.nasdanika.rag.core.StringMapDoubleVectorKeyExtractor;

public class OpenAIEmbeddingsKeyExtractor
implements KeyExtractor<List<String>, List<List<Double>>> {
    private OpenAIClient client;
    private String model;
    private String user;
    private String deploymentOrModelId;

    public OpenAIEmbeddingsKeyExtractor(OpenAIClient client, String deploymentOrModelId, String model, String user) {
        this.client = client;
        this.deploymentOrModelId = deploymentOrModelId;
        this.model = model;
        this.user = user;
    }

    public List<List<Double>> extract(List<String> value, ProgressMonitor progressMonitor) {
        EmbeddingsOptions embeddingsOptions = new EmbeddingsOptions(value);
        if (this.model != null) {
            embeddingsOptions.setModel(this.model);
        }
        if (this.user != null) {
            embeddingsOptions.setUser(this.user);
        }
        Embeddings embeddings = this.client.getEmbeddings(this.deploymentOrModelId, embeddingsOptions);
        return embeddings.getData().stream().map(EmbeddingItem::getEmbedding).toList();
    }

    public StringDoubleVectorKeyExtractor asStringDoubleVectorKeyExtractor() {
        return (value, progressMonitor) -> this.extract(Collections.singletonList(value), progressMonitor).get(0);
    }

    public StringMapDoubleVectorKeyExtractor asStringMapDoubleVectorKeyExtractor() {
        return (value, progressMonitor) -> {
            ArrayList keys = new ArrayList();
            ArrayList<String> values = new ArrayList<String>();
            value.entrySet().forEach(e -> {
                keys.add((String)e.getKey());
                values.add((String)e.getValue());
            });
            LinkedHashMap<String, List<Double>> result = new LinkedHashMap<String, List<Double>>();
            List<List<Double>> embeddings = this.extract(values, progressMonitor);
            for (int i = 0; i < embeddings.size(); ++i) {
                result.put((String)keys.get(i), embeddings.get(i));
            }
            return result;
        };
    }

    public <T extends KeyExtractor<?, ?>> T adapt(Class<T> type) {
        if (type.isInstance(this)) {
            return (T)this;
        }
        if (type.isAssignableFrom(StringDoubleVectorKeyExtractor.class)) {
            return (T)this.asStringDoubleVectorKeyExtractor();
        }
        if (type.isAssignableFrom(StringMapDoubleVectorKeyExtractor.class)) {
            return (T)this.asStringMapDoubleVectorKeyExtractor();
        }
        return (T)super.adapt(type);
    }
}

