001package org.nasdanika.rag.openai; 002 003import java.util.ArrayList; 004import java.util.Collections; 005import java.util.LinkedHashMap; 006import java.util.List; 007import java.util.Map; 008 009import org.nasdanika.common.ProgressMonitor; 010import org.nasdanika.rag.core.KeyExtractor; 011import org.nasdanika.rag.core.StringDoubleVectorKeyExtractor; 012import org.nasdanika.rag.core.StringMapDoubleVectorKeyExtractor; 013 014import com.azure.ai.openai.OpenAIClient; 015import com.azure.ai.openai.models.EmbeddingItem; 016import com.azure.ai.openai.models.Embeddings; 017import com.azure.ai.openai.models.EmbeddingsOptions; 018 019/** 020 * 021 */ 022public class OpenAIEmbeddingsKeyExtractor implements KeyExtractor<List<String>, List<List<Double>>> { 023 024 private OpenAIClient client; 025 private String model; 026 private String user; 027 private String deploymentOrModelId; 028 029 public OpenAIEmbeddingsKeyExtractor( 030 OpenAIClient client, 031 String deploymentOrModelId, 032 String model, 033 String user) { 034 this.client = client; 035 this.deploymentOrModelId = deploymentOrModelId; 036 this.model = model; 037 this.user = user; 038 } 039 040 @Override 041 public List<List<Double>> extract(List<String> value, ProgressMonitor progressMonitor) { 042 EmbeddingsOptions embeddingsOptions = new EmbeddingsOptions(value); 043 if (model != null) { 044 embeddingsOptions.setModel(model); 045 } 046 if (user != null) { 047 embeddingsOptions.setUser(user); 048 } 049 Embeddings embeddings = client.getEmbeddings(deploymentOrModelId, embeddingsOptions); 050 return embeddings.getData().stream().map(EmbeddingItem::getEmbedding).toList(); 051 } 052 053 public StringDoubleVectorKeyExtractor asStringDoubleVectorKeyExtractor() { 054 return (value, progressMonitor) -> extract(Collections.singletonList(value), progressMonitor).get(0); 055 } 056 057 public StringMapDoubleVectorKeyExtractor asStringMapDoubleVectorKeyExtractor() { 058 return (value, progressMonitor) -> { 059 List<String> keys = new ArrayList<>(); 060 List<String> values = new ArrayList<>(); 061 value.entrySet().forEach(e -> { 062 keys.add(e.getKey()); 063 values.add(e.getValue()); 064 }); 065 066 Map<String, List<Double>> result = new LinkedHashMap<>(); 067 List<List<Double>> embeddings = extract(values, progressMonitor); 068 for (int i = 0; i < embeddings.size(); ++i) { 069 result.put(keys.get(i), embeddings.get(i)); 070 } 071 072 return result; 073 }; 074 } 075 076 @SuppressWarnings("unchecked") 077 @Override 078 public <T extends KeyExtractor<?, ?>> T adapt(Class<T> type) { 079 if (type.isInstance(this)) { 080 return (T) this; 081 } 082 083 if (type.isAssignableFrom(StringDoubleVectorKeyExtractor.class)) { 084 return (T) asStringDoubleVectorKeyExtractor(); 085 } 086 087 if (type.isAssignableFrom(StringMapDoubleVectorKeyExtractor.class)) { 088 return (T) asStringMapDoubleVectorKeyExtractor(); 089 } 090 091 return KeyExtractor.super.adapt(type); 092 } 093 094}