001package org.nasdanika.rag.openai;
002
003import java.util.ArrayList;
004import java.util.Collections;
005import java.util.LinkedHashMap;
006import java.util.List;
007import java.util.Map;
008
009import org.nasdanika.common.ProgressMonitor;
010import org.nasdanika.rag.core.KeyExtractor;
011import org.nasdanika.rag.core.StringDoubleVectorKeyExtractor;
012import org.nasdanika.rag.core.StringFloatVectorKeyExtractor;
013import org.nasdanika.rag.core.StringMapDoubleVectorKeyExtractor;
014import org.nasdanika.rag.core.StringMapFloatVectorKeyExtractor;
015
016import com.azure.ai.openai.OpenAIClient;
017import com.azure.ai.openai.models.EmbeddingItem;
018import com.azure.ai.openai.models.Embeddings;
019import com.azure.ai.openai.models.EmbeddingsOptions;
020
021/**
022 * 
023 */
024public class OpenAIEmbeddingsKeyExtractor implements KeyExtractor<List<String>, List<List<Float>>> {
025        
026        private OpenAIClient client;
027        private String model;
028        private String user;
029        private String deploymentOrModelId;
030
031        public OpenAIEmbeddingsKeyExtractor(
032                        OpenAIClient client,
033                        String deploymentOrModelId,
034                        String model, 
035                        String user) {
036                this.client = client;
037                this.deploymentOrModelId = deploymentOrModelId;
038                this.model = model;
039                this.user = user;
040        }
041
042        @Override
043        public List<List<Float>> extract(List<String> value, ProgressMonitor progressMonitor) {
044                EmbeddingsOptions embeddingsOptions = new EmbeddingsOptions(value);
045                if (model != null) {
046                        embeddingsOptions.setModel(model);
047                }
048                if (user != null) {
049                        embeddingsOptions.setUser(user);
050                }
051                Embeddings embeddings = client.getEmbeddings(deploymentOrModelId, embeddingsOptions);
052                return embeddings.getData().stream().map(EmbeddingItem::getEmbedding).toList();
053        }
054        
055        public StringFloatVectorKeyExtractor asStringFloatVectorKeyExtractor() {                
056                return (value, progressMonitor) -> extract(Collections.singletonList(value), progressMonitor).get(0);
057        }
058        
059        public StringMapFloatVectorKeyExtractor asStringMapFloatVectorKeyExtractor() {
060                return (value, progressMonitor) -> {
061                        List<String> keys = new ArrayList<>();
062                        List<String> values = new ArrayList<>();
063                        value.entrySet().forEach(e -> {
064                                keys.add(e.getKey());
065                                values.add(e.getValue());
066                        });
067                        
068                        Map<String, List<Float>> result = new LinkedHashMap<>();
069                        List<List<Float>> embeddings = extract(values, progressMonitor);
070                        for (int i = 0; i < embeddings.size(); ++i) {
071                                result.put(keys.get(i), embeddings.get(i));
072                        }
073                        
074                        return result;
075                };
076        }
077        
078        @SuppressWarnings("unchecked")
079        @Override
080        public <T extends KeyExtractor<?, ?>> T adapt(Class<T> type) {
081                if (type.isInstance(this)) {
082                        return (T) this;
083                }
084                
085                if (type.isAssignableFrom(StringFloatVectorKeyExtractor.class)) {
086                        return (T) asStringFloatVectorKeyExtractor();
087                }
088                
089                if (type.isAssignableFrom(StringMapFloatVectorKeyExtractor.class)) {
090                        return (T) asStringMapFloatVectorKeyExtractor();
091                }
092
093                return KeyExtractor.super.adapt(type);
094        }
095
096}