001package org.nasdanika.ai.cli;
002
003import com.github.jelmerk.hnswlib.core.DistanceFunction;
004import com.github.jelmerk.hnswlib.core.DistanceFunctions;
005import com.github.jelmerk.hnswlib.jdk17.Jdk17DistanceFunctions;
006import com.github.jelmerk.hnswlib.util.VectorUtils;
007
008import io.opentelemetry.api.trace.Span;
009import picocli.CommandLine.Option;
010
011public class HnswIndexBuilderFloatArgGroup extends HnswIndexBuilderArgGroup<float[], Float> {
012        
013        public enum Distance {
014                
015                BRAY_CURTIS(DistanceFunctions.FLOAT_BRAY_CURTIS_DISTANCE),
016                CANBERRA(DistanceFunctions.FLOAT_CANBERRA_DISTANCE),
017                CORRELATION(DistanceFunctions.FLOAT_CORRELATION_DISTANCE),
018                COSINE(DistanceFunctions.FLOAT_COSINE_DISTANCE),
019                EUCLIDEAN(DistanceFunctions.FLOAT_EUCLIDEAN_DISTANCE),
020                INNER_PRODUCT(DistanceFunctions.FLOAT_INNER_PRODUCT),
021                MANHATTAN(DistanceFunctions.FLOAT_MANHATTAN_DISTANCE),
022                
023                // JDK 17
024                
025                VECTOR_FLOAT_128_BRAY_CURTIS(Jdk17DistanceFunctions.VECTOR_FLOAT_128_BRAY_CURTIS_DISTANCE),
026                VECTOR_FLOAT_128_CANBERRA(Jdk17DistanceFunctions.VECTOR_FLOAT_128_CANBERRA_DISTANCE),
027                VECTOR_FLOAT_128_COSINE(Jdk17DistanceFunctions.VECTOR_FLOAT_128_COSINE_DISTANCE),
028                VECTOR_FLOAT_128_EUCLIDEAN(Jdk17DistanceFunctions.VECTOR_FLOAT_128_EUCLIDEAN_DISTANCE),
029                VECTOR_FLOAT_128_INNER_PRODUCT(Jdk17DistanceFunctions.VECTOR_FLOAT_128_INNER_PRODUCT),
030                VECTOR_FLOAT_128_MANHATTAN(Jdk17DistanceFunctions.VECTOR_FLOAT_128_MANHATTAN_DISTANCE),
031                VECTOR_FLOAT_256_BRAY_CURTIS(Jdk17DistanceFunctions.VECTOR_FLOAT_256_BRAY_CURTIS_DISTANCE),
032                VECTOR_FLOAT_256_CANBERRA(Jdk17DistanceFunctions.VECTOR_FLOAT_256_CANBERRA_DISTANCE),
033                VECTOR_FLOAT_256_COSINE(Jdk17DistanceFunctions.VECTOR_FLOAT_256_COSINE_DISTANCE),
034                VECTOR_FLOAT_256_EUCLIDEAN(Jdk17DistanceFunctions.VECTOR_FLOAT_256_EUCLIDEAN_DISTANCE),
035                VECTOR_FLOAT_256_INNER_PRODUCT(Jdk17DistanceFunctions.VECTOR_FLOAT_256_INNER_PRODUCT),
036                VECTOR_FLOAT_256_MANHATTAN(Jdk17DistanceFunctions.VECTOR_FLOAT_256_MANHATTAN_DISTANCE);                                 
037                
038                public final DistanceFunction<float[], Float> distanceFunction;
039                
040                Distance(DistanceFunction<float[], Float> distanceFunction) {
041                        this.distanceFunction = distanceFunction;
042                }
043                
044        }
045        
046        @Option( 
047                        names = "--hnsw-distance-function",
048                        description = {
049                                        "Vector distance function",
050                                        "Valid values: ${COMPLETION-CANDIDATES}",
051                                        "Default value: COSINE"
052                        })      
053        protected Distance distanceFunction = Distance.COSINE;  
054        
055        @Option( 
056                        names = "--hnsw-normalize",
057                        description = "If true, vectors are normalized"
058                        )       
059        protected boolean normalize;            
060
061        @Override
062        protected DistanceFunction<float[], Float> getDistanceFunction() {
063                return distanceFunction.distanceFunction;
064        }
065                
066        /**
067         * Normalizes the argument vector if normalize option is true
068         * @param vector
069         * @return
070         */
071        public float[] normalize(float[] vector) {
072                return normalize ? VectorUtils.normalize(vector) : vector;
073        }       
074        
075        @Override
076        public void setSpanAttributes(Span span) {
077                span.setAttribute("hnsw.distance-function", distanceFunction.name());
078                span.setAttribute("hnsw.normalize", normalize);
079        }       
080
081}