/*
 * Decompiled with CFR 0.152.
 */
package org.noear.solon.ai.embedding;

import java.io.IOException;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.noear.solon.ai.AiModel;
import org.noear.solon.ai.embedding.Embedding;
import org.noear.solon.ai.embedding.EmbeddingConfig;
import org.noear.solon.ai.embedding.EmbeddingRequest;
import org.noear.solon.ai.embedding.EmbeddingResponse;
import org.noear.solon.ai.embedding.dialect.EmbeddingDialect;
import org.noear.solon.ai.embedding.dialect.EmbeddingDialectManager;
import org.noear.solon.ai.rag.Document;
import org.noear.solon.lang.Preview;

/**
 * Embedding model client.
 *
 * <p>Holds an {@link EmbeddingConfig} together with the {@link EmbeddingDialect} that
 * {@link EmbeddingDialectManager#select} chose for it, and creates {@link EmbeddingRequest}s
 * bound to both. Instances are normally obtained through {@link #of(EmbeddingConfig)} or
 * {@link #of(String)} and the returned {@link Builder}.
 */
@Preview(value="3.1")
public class EmbeddingModel
implements AiModel {
    private final EmbeddingConfig config;
    private final EmbeddingDialect dialect;

    /**
     * Creates a model for the given configuration.
     *
     * @param config model configuration; also used to select the request dialect
     */
    protected EmbeddingModel(EmbeddingConfig config) {
        this.dialect = EmbeddingDialectManager.select(config);
        this.config = config;
    }

    /**
     * Embeds a single text.
     *
     * @param text the text to embed
     * @return the embedding vector of the first entry in the response data
     * @throws IOException if the response carries no embedding data; the response's own
     *                     error (when {@code getError()} is non-null) is rethrown as-is
     */
    public float[] embed(String text) throws IOException {
        EmbeddingResponse resp = this.input(text).call();
        if (resp.getError() != null) {
            throw resp.getError();
        }
        List<Embedding> data = resp.getData();
        // Fail with a descriptive checked exception instead of an unchecked
        // NullPointerException / IndexOutOfBoundsException on an empty payload.
        if (data == null || data.isEmpty()) {
            throw new IOException("Embedding response contains no data");
        }
        return data.get(0).getEmbedding();
    }

    /**
     * Returns the length of the vectors this model produces.
     *
     * <p>Each call performs a full {@link #embed(String)} of the probe text {@code "test"}
     * (presumably a remote request); callers that need the value repeatedly should cache it.
     *
     * @return the embedding vector length
     * @throws IOException if the probe embedding fails
     */
    public int dimensions() throws IOException {
        return this.embed("test").length;
    }

    /**
     * Embeds the content of every document in one request and stores each resulting vector
     * on the document at the same position.
     *
     * <p>No request is made for an empty list. If the response does not contain exactly one
     * embedding per document, no document is modified and an {@link IOException} is thrown.
     *
     * @param documents documents whose {@code getContent()} is embedded; updated in place
     * @throws IOException if the embedding count does not match the document count; the
     *                     response's own error (when {@code getError()} is non-null) is
     *                     rethrown as-is
     */
    public void embed(List<Document> documents) throws IOException {
        if (documents.isEmpty()) {
            // Nothing to embed; avoid sending an empty request.
            return;
        }

        List<String> texts = new ArrayList<>(documents.size());
        for (Document doc : documents) {
            texts.add(doc.getContent());
        }

        EmbeddingResponse resp = this.input(texts).call();
        if (resp.getError() != null) {
            throw resp.getError();
        }

        List<Embedding> embeddings = resp.getData();
        int received = embeddings == null ? 0 : embeddings.size();
        // Validate before mutating anything so documents are never left half-embedded.
        if (received != documents.size()) {
            throw new IOException("Embedding count mismatch: expected "
                    + documents.size() + ", got " + received);
        }

        for (int i = 0; i < received; ++i) {
            documents.get(i).embedding(embeddings.get(i).getEmbedding());
        }
    }

    /**
     * Creates a request for the given texts.
     *
     * @param input texts to embed
     * @return a request bound to this model's config and dialect
     */
    public EmbeddingRequest input(String ... input) {
        return this.input(Arrays.asList(input));
    }

    /**
     * Creates a request for the given texts.
     *
     * @param input texts to embed
     * @return a request bound to this model's config and dialect
     */
    public EmbeddingRequest input(List<String> input) {
        return new EmbeddingRequest(this.config, this.dialect, input);
    }

    /**
     * Starts a builder that configures (and mutates) the given config instance.
     *
     * @param config the configuration to build on
     * @return a new builder
     */
    public static Builder of(EmbeddingConfig config) {
        return new Builder(config);
    }

    /**
     * Starts a builder with a fresh config whose API URL is set.
     *
     * @param apiUrl the API URL
     * @return a new builder
     */
    public static Builder of(String apiUrl) {
        return new Builder(apiUrl);
    }

    /**
     * Fluent builder for {@link EmbeddingModel}; every setter writes directly into the
     * underlying {@link EmbeddingConfig}.
     */
    public static class Builder {
        private final EmbeddingConfig config;

        /**
         * Creates a builder over a new {@link EmbeddingConfig} with the API URL set.
         *
         * @param apiUrl the API URL
         */
        public Builder(String apiUrl) {
            this.config = new EmbeddingConfig();
            this.config.setApiUrl(apiUrl);
        }

        /**
         * Creates a builder over an existing config (which subsequent calls will mutate).
         *
         * @param config the configuration to build on
         */
        public Builder(EmbeddingConfig config) {
            this.config = config;
        }

        /** Sets the API key on the config. */
        public Builder apiKey(String apiKey) {
            this.config.setApiKey(apiKey);
            return this;
        }

        /** Sets the provider on the config. */
        public Builder provider(String provider) {
            this.config.setProvider(provider);
            return this;
        }

        /** Sets the model name on the config. */
        public Builder model(String model) {
            this.config.setModel(model);
            return this;
        }

        /** Sets a header entry on the config. */
        public Builder headerSet(String key, String value) {
            this.config.setHeader(key, value);
            return this;
        }

        /** Sets the timeout on the config. */
        public Builder timeout(Duration timeout) {
            this.config.setTimeout(timeout);
            return this;
        }

        /**
         * Builds the model from the accumulated config.
         *
         * @return a new {@link EmbeddingModel}
         */
        public EmbeddingModel build() {
            return new EmbeddingModel(this.config);
        }
    }
}

