001package org.nasdanika.ai.cli; 002 003import com.github.jelmerk.hnswlib.core.DistanceFunction; 004import com.github.jelmerk.hnswlib.core.DistanceFunctions; 005//import com.github.jelmerk.hnswlib.jdk17.Jdk17DistanceFunctions; 006import com.github.jelmerk.hnswlib.util.VectorUtils; 007 008import io.opentelemetry.api.trace.Span; 009import picocli.CommandLine.Option; 010 011public class HnswIndexBuilderFloatArgGroup extends HnswIndexBuilderArgGroup<float[], Float> { 012 013 public enum Distance { 014 015 BRAY_CURTIS(DistanceFunctions.FLOAT_BRAY_CURTIS_DISTANCE), 016 CANBERRA(DistanceFunctions.FLOAT_CANBERRA_DISTANCE), 017 CORRELATION(DistanceFunctions.FLOAT_CORRELATION_DISTANCE), 018 COSINE(DistanceFunctions.FLOAT_COSINE_DISTANCE), 019 EUCLIDEAN(DistanceFunctions.FLOAT_EUCLIDEAN_DISTANCE), 020 INNER_PRODUCT(DistanceFunctions.FLOAT_INNER_PRODUCT), 021 MANHATTAN(DistanceFunctions.FLOAT_MANHATTAN_DISTANCE); 022 023 // JDK 17 024 025// VECTOR_FLOAT_128_BRAY_CURTIS(Jdk17DistanceFunctions.VECTOR_FLOAT_128_BRAY_CURTIS_DISTANCE), 026// VECTOR_FLOAT_128_CANBERRA(Jdk17DistanceFunctions.VECTOR_FLOAT_128_CANBERRA_DISTANCE), 027// VECTOR_FLOAT_128_COSINE(Jdk17DistanceFunctions.VECTOR_FLOAT_128_COSINE_DISTANCE), 028// VECTOR_FLOAT_128_EUCLIDEAN(Jdk17DistanceFunctions.VECTOR_FLOAT_128_EUCLIDEAN_DISTANCE), 029// VECTOR_FLOAT_128_INNER_PRODUCT(Jdk17DistanceFunctions.VECTOR_FLOAT_128_INNER_PRODUCT), 030// VECTOR_FLOAT_128_MANHATTAN(Jdk17DistanceFunctions.VECTOR_FLOAT_128_MANHATTAN_DISTANCE), 031// VECTOR_FLOAT_256_BRAY_CURTIS(Jdk17DistanceFunctions.VECTOR_FLOAT_256_BRAY_CURTIS_DISTANCE), 032// VECTOR_FLOAT_256_CANBERRA(Jdk17DistanceFunctions.VECTOR_FLOAT_256_CANBERRA_DISTANCE), 033// VECTOR_FLOAT_256_COSINE(Jdk17DistanceFunctions.VECTOR_FLOAT_256_COSINE_DISTANCE), 034// VECTOR_FLOAT_256_EUCLIDEAN(Jdk17DistanceFunctions.VECTOR_FLOAT_256_EUCLIDEAN_DISTANCE), 035// VECTOR_FLOAT_256_INNER_PRODUCT(Jdk17DistanceFunctions.VECTOR_FLOAT_256_INNER_PRODUCT), 036// VECTOR_FLOAT_256_MANHATTAN(Jdk17DistanceFunctions.VECTOR_FLOAT_256_MANHATTAN_DISTANCE); 037 038 public final DistanceFunction<float[], Float> distanceFunction; 039 040 Distance(DistanceFunction<float[], Float> distanceFunction) { 041 this.distanceFunction = distanceFunction; 042 } 043 044 } 045 046 @Option( 047 names = "--hnsw-distance-function", 048 description = { 049 "Vector distance function", 050 "Valid values: ${COMPLETION-CANDIDATES}", 051 "Default value: COSINE" 052 }) 053 protected Distance distanceFunction = Distance.COSINE; 054 055 @Option( 056 names = "--hnsw-normalize", 057 description = "If true, vectors are normalized" 058 ) 059 protected boolean normalize; 060 061 @Override 062 protected DistanceFunction<float[], Float> getDistanceFunction() { 063 return distanceFunction.distanceFunction; 064 } 065 066 /** 067 * Normalizes the argument vector if normalize option is true 068 * @param vector 069 * @return 070 */ 071 public float[] normalize(float[] vector) { 072 return normalize ? VectorUtils.normalize(vector) : vector; 073 } 074 075 @Override 076 public void setSpanAttributes(Span span) { 077 span.setAttribute("hnsw.distance-function", distanceFunction.name()); 078 span.setAttribute("hnsw.normalize", normalize); 079 } 080 081}