001package org.nasdanika.rag.openai;
002
003import java.util.ArrayList;
004import java.util.Collections;
005import java.util.LinkedHashMap;
006import java.util.List;
007import java.util.Map;
008
009import org.nasdanika.common.ProgressMonitor;
010import org.nasdanika.rag.core.KeyExtractor;
011import org.nasdanika.rag.core.StringDoubleVectorKeyExtractor;
012import org.nasdanika.rag.core.StringMapDoubleVectorKeyExtractor;
013
014import com.azure.ai.openai.OpenAIClient;
015import com.azure.ai.openai.models.EmbeddingItem;
016import com.azure.ai.openai.models.Embeddings;
017import com.azure.ai.openai.models.EmbeddingsOptions;
018
019/**
020 * 
021 */
022public class OpenAIEmbeddingsKeyExtractor implements KeyExtractor<List<String>, List<List<Double>>> {
023        
024        private OpenAIClient client;
025        private String model;
026        private String user;
027        private String deploymentOrModelId;
028
029        public OpenAIEmbeddingsKeyExtractor(
030                        OpenAIClient client,
031                        String deploymentOrModelId,
032                        String model, 
033                        String user) {
034                this.client = client;
035                this.deploymentOrModelId = deploymentOrModelId;
036                this.model = model;
037                this.user = user;
038        }
039
040        @Override
041        public List<List<Double>> extract(List<String> value, ProgressMonitor progressMonitor) {
042                EmbeddingsOptions embeddingsOptions = new EmbeddingsOptions(value);
043                if (model != null) {
044                        embeddingsOptions.setModel(model);
045                }
046                if (user != null) {
047                        embeddingsOptions.setUser(user);
048                }
049                Embeddings embeddings = client.getEmbeddings(deploymentOrModelId, embeddingsOptions);
050                return embeddings.getData().stream().map(EmbeddingItem::getEmbedding).toList();
051        }
052        
053        public StringDoubleVectorKeyExtractor asStringDoubleVectorKeyExtractor() {              
054                return (value, progressMonitor) -> extract(Collections.singletonList(value), progressMonitor).get(0);
055        }
056        
057        public StringMapDoubleVectorKeyExtractor asStringMapDoubleVectorKeyExtractor() {
058                return (value, progressMonitor) -> {
059                        List<String> keys = new ArrayList<>();
060                        List<String> values = new ArrayList<>();
061                        value.entrySet().forEach(e -> {
062                                keys.add(e.getKey());
063                                values.add(e.getValue());
064                        });
065                        
066                        Map<String, List<Double>> result = new LinkedHashMap<>();
067                        List<List<Double>> embeddings = extract(values, progressMonitor);
068                        for (int i = 0; i < embeddings.size(); ++i) {
069                                result.put(keys.get(i), embeddings.get(i));
070                        }
071                        
072                        return result;
073                };
074        }
075        
076        @SuppressWarnings("unchecked")
077        @Override
078        public <T extends KeyExtractor<?, ?>> T adapt(Class<T> type) {
079                if (type.isInstance(this)) {
080                        return (T) this;
081                }
082                
083                if (type.isAssignableFrom(StringDoubleVectorKeyExtractor.class)) {
084                        return (T) asStringDoubleVectorKeyExtractor();
085                }
086                
087                if (type.isAssignableFrom(StringMapDoubleVectorKeyExtractor.class)) {
088                        return (T) asStringMapDoubleVectorKeyExtractor();
089                }
090
091                return KeyExtractor.super.adapt(type);
092        }
093
094}