001package org.nasdanika.rag.openai; 002 003import java.util.ArrayList; 004import java.util.Collections; 005import java.util.LinkedHashMap; 006import java.util.List; 007import java.util.Map; 008 009import org.nasdanika.common.ProgressMonitor; 010import org.nasdanika.rag.core.KeyExtractor; 011import org.nasdanika.rag.core.StringDoubleVectorKeyExtractor; 012import org.nasdanika.rag.core.StringFloatVectorKeyExtractor; 013import org.nasdanika.rag.core.StringMapDoubleVectorKeyExtractor; 014import org.nasdanika.rag.core.StringMapFloatVectorKeyExtractor; 015 016import com.azure.ai.openai.OpenAIClient; 017import com.azure.ai.openai.models.EmbeddingItem; 018import com.azure.ai.openai.models.Embeddings; 019import com.azure.ai.openai.models.EmbeddingsOptions; 020 021/** 022 * 023 */ 024public class OpenAIEmbeddingsKeyExtractor implements KeyExtractor<List<String>, List<List<Float>>> { 025 026 private OpenAIClient client; 027 private String model; 028 private String user; 029 private String deploymentOrModelId; 030 031 public OpenAIEmbeddingsKeyExtractor( 032 OpenAIClient client, 033 String deploymentOrModelId, 034 String model, 035 String user) { 036 this.client = client; 037 this.deploymentOrModelId = deploymentOrModelId; 038 this.model = model; 039 this.user = user; 040 } 041 042 @Override 043 public List<List<Float>> extract(List<String> value, ProgressMonitor progressMonitor) { 044 EmbeddingsOptions embeddingsOptions = new EmbeddingsOptions(value); 045 if (model != null) { 046 embeddingsOptions.setModel(model); 047 } 048 if (user != null) { 049 embeddingsOptions.setUser(user); 050 } 051 Embeddings embeddings = client.getEmbeddings(deploymentOrModelId, embeddingsOptions); 052 return embeddings.getData().stream().map(EmbeddingItem::getEmbedding).toList(); 053 } 054 055 public StringFloatVectorKeyExtractor asStringFloatVectorKeyExtractor() { 056 return (value, progressMonitor) -> extract(Collections.singletonList(value), progressMonitor).get(0); 057 } 058 059 public StringMapFloatVectorKeyExtractor asStringMapFloatVectorKeyExtractor() { 060 return (value, progressMonitor) -> { 061 List<String> keys = new ArrayList<>(); 062 List<String> values = new ArrayList<>(); 063 value.entrySet().forEach(e -> { 064 keys.add(e.getKey()); 065 values.add(e.getValue()); 066 }); 067 068 Map<String, List<Float>> result = new LinkedHashMap<>(); 069 List<List<Float>> embeddings = extract(values, progressMonitor); 070 for (int i = 0; i < embeddings.size(); ++i) { 071 result.put(keys.get(i), embeddings.get(i)); 072 } 073 074 return result; 075 }; 076 } 077 078 @SuppressWarnings("unchecked") 079 @Override 080 public <T extends KeyExtractor<?, ?>> T adapt(Class<T> type) { 081 if (type.isInstance(this)) { 082 return (T) this; 083 } 084 085 if (type.isAssignableFrom(StringFloatVectorKeyExtractor.class)) { 086 return (T) asStringFloatVectorKeyExtractor(); 087 } 088 089 if (type.isAssignableFrom(StringMapFloatVectorKeyExtractor.class)) { 090 return (T) asStringMapFloatVectorKeyExtractor(); 091 } 092 093 return KeyExtractor.super.adapt(type); 094 } 095 096}