/*
 * Decompiled with CFR 0.152.
 */
package org.nasdanika.ai;

import com.github.jelmerk.hnswlib.core.Index;
import com.github.jelmerk.hnswlib.core.Item;
import java.io.Serializable;
import java.lang.invoke.MethodHandle;
import java.lang.runtime.ObjectMethods;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.function.Function;
import org.nasdanika.ai.SearchResult;
import org.nasdanika.ai.TextFloatVectorEmbeddingModel;
import reactor.core.publisher.Mono;

/**
 * Similarity search that, for a query of type {@code T}, returns search results whose
 * distances are of type {@code D}.
 *
 * @param <T> query type
 * @param <D> distance type
 */
public interface SimilaritySearch<T, D extends Comparable<D>> {

    /**
     * Synchronously finds up to {@code numberOfItems} results for the query.
     *
     * @param query query to search for
     * @param numberOfItems maximum number of results requested
     * @return search results
     */
    List<SearchResult<D>> find(T query, int numberOfItems);

    /**
     * Asynchronous counterpart of {@link #find(Object, int)}.
     *
     * @param query query to search for
     * @param numberOfItems maximum number of results requested
     * @return a {@link Mono} emitting the search results
     */
    Mono<List<SearchResult<D>>> findAsync(T query, int numberOfItems);

    /**
     * Creates a search accepting queries of type {@code U} by converting them to {@code T}
     * and delegating to this search.
     *
     * @param mapper converts a query for the synchronous {@code find}
     * @param asyncMapper converts a query for {@code findAsync}; its result is flat-mapped into this search's {@code findAsync}
     * @return adapted search delegating to this instance
     */
    default <U> SimilaritySearch<U, D> adapt(final Function<U, T> mapper, final Function<U, Mono<T>> asyncMapper) {
        return new SimilaritySearch<U, D>() {

            @Override
            public List<SearchResult<D>> find(U query, int numberOfItems) {
                return SimilaritySearch.this.find(mapper.apply(query), numberOfItems);
            }

            @Override
            public Mono<List<SearchResult<D>>> findAsync(U query, int numberOfItems) {
                return asyncMapper.apply(query)
                        .flatMap(mappedQuery -> SimilaritySearch.this.findAsync(mappedQuery, numberOfItems));
            }
        };
    }

    /**
     * Creates a text search from a multi-vector search by embedding the query text with
     * {@code embeddings.generate} (sync) / {@code embeddings.generateAsync} (async).
     *
     * @param multiVectorSearch search over lists of float vectors
     * @param embeddings model used to turn query text into vectors
     * @return text search
     */
    static <D extends Comparable<D>> SimilaritySearch<String, D> textFloatVectorEmbeddingSearch(SimilaritySearch<List<List<Float>>, D> multiVectorSearch, TextFloatVectorEmbeddingModel embeddings) {
        return multiVectorSearch.adapt(embeddings::generate, embeddings::generateAsync);
    }

    /**
     * Lifts a single-vector search to a multi-vector search. Each query vector is searched
     * independently for {@code numberOfItems} results; all results are then merged, sorted
     * by {@link SearchResult}'s natural ordering and truncated to {@code numberOfItems}.
     * Results from different query vectors are not de-duplicated.
     *
     * @param vectorSearch search over a single vector
     * @return multi-vector search
     */
    static <D extends Comparable<D>> SimilaritySearch<List<List<Float>>, D> adapt(final SimilaritySearch<List<Float>, D> vectorSearch) {
        return new SimilaritySearch<List<List<Float>>, D>() {

            @Override
            public List<SearchResult<D>> find(List<List<Float>> query, int numberOfItems) {
                List<List<SearchResult<D>>> perVector = new ArrayList<>(query.size());
                for (List<Float> vector : query) {
                    perVector.add(vectorSearch.find(vector, numberOfItems));
                }
                return SimilaritySearch.mergeTopResults(perVector, numberOfItems);
            }

            @Override
            public Mono<List<SearchResult<D>>> findAsync(List<List<Float>> query, int numberOfItems) {
                if (query.isEmpty()) {
                    // Mono.zip over an empty Iterable completes without emitting a value;
                    // emit an empty list instead so the result matches find().
                    return Mono.just(new ArrayList<>());
                }
                List<Mono<List<SearchResult<D>>>> perVector = new ArrayList<>(query.size());
                for (List<Float> vector : query) {
                    perVector.add(vectorSearch.findAsync(vector, numberOfItems));
                }
                return Mono.zip(perVector, partials -> {
                    List<List<SearchResult<D>>> lists = new ArrayList<>(partials.length);
                    for (Object partial : partials) {
                        // Each element is the value emitted by one of the Monos built above.
                        @SuppressWarnings("unchecked")
                        List<SearchResult<D>> typed = (List<SearchResult<D>>) partial;
                        lists.add(typed);
                    }
                    return SimilaritySearch.mergeTopResults(lists, numberOfItems);
                });
            }
        };
    }

    /**
     * Concatenates per-vector results, sorts them by natural ordering and keeps at most
     * {@code numberOfItems} of them (as an independent list, not a sub-list view).
     */
    private static <D extends Comparable<D>> List<SearchResult<D>> mergeTopResults(List<List<SearchResult<D>>> partials, int numberOfItems) {
        List<SearchResult<D>> merged = new ArrayList<>();
        for (List<SearchResult<D>> partial : partials) {
            merged.addAll(partial);
        }
        Collections.sort(merged);
        return merged.size() > numberOfItems ? new ArrayList<>(merged.subList(0, numberOfItems)) : merged;
    }

    /**
     * Creates a search backed by an HNSW index without query normalization.
     *
     * @param index HNSW index of embedding items
     * @return vector search over the index
     */
    static SimilaritySearch<List<Float>, Float> from(Index<IndexId, float[], EmbeddingsItem, Float> index) {
        return SimilaritySearch.from(index, Function.identity());
    }

    /**
     * Creates a search backed by an HNSW index.
     *
     * @param index HNSW index of embedding items
     * @param normalizer applied to the query vector before lookup; {@code null} means no normalization
     * @return vector search over the index
     */
    static SimilaritySearch<List<Float>, Float> from(final Index<IndexId, float[], EmbeddingsItem, Float> index, final Function<float[], float[]> normalizer) {
        return new SimilaritySearch<List<Float>, Float>() {

            @Override
            public Mono<List<SearchResult<Float>>> findAsync(List<Float> query, int numberOfItems) {
                // Defer the index lookup to subscription time so that it is not executed eagerly
                // on the calling thread and failures are delivered as error signals.
                return Mono.fromSupplier(() -> find(query, numberOfItems));
            }

            @Override
            public List<SearchResult<Float>> find(List<Float> query, int numberOfItems) {
                float[] vector = new float[query.size()];
                for (int i = 0; i < vector.length; ++i) {
                    vector[i] = query.get(i).floatValue();
                }
                if (normalizer != null) {
                    vector = normalizer.apply(vector);
                }
                List<SearchResult<Float>> results = new ArrayList<>();
                for (com.github.jelmerk.hnswlib.core.SearchResult<EmbeddingsItem, Float> nearest : index.findNearest(vector, numberOfItems)) {
                    final String uri = nearest.item().id().uri();
                    final int itemIndex = nearest.item().id().index();
                    final Float distance = nearest.distance();
                    results.add(new SearchResult<Float>() {

                        @Override
                        public String getUri() {
                            return uri;
                        }

                        @Override
                        public int getIndex() {
                            return itemIndex;
                        }

                        @Override
                        public Float getDistance() {
                            return distance;
                        }
                    });
                }
                return results;
            }
        };
    }

    /**
     * Index entry: an embedding vector identified by {@link IndexId}.
     * Note: as for any record with an array component, equality compares the
     * {@code vector} array by reference, not by content.
     *
     * @param id identifier of the embedded element
     * @param vector embedding vector
     * @param dimensions number of dimensions reported to the index
     */
    public record EmbeddingsItem(IndexId id, float[] vector, int dimensions) implements Item<IndexId, float[]> {
    }

    /**
     * Identifier of an indexed embedding: a URI plus an integer index.
     *
     * @param uri URI of the source element
     * @param index position associated with the embedding (e.g. chunk number — confirm with the indexer)
     */
    public record IndexId(String uri, int index) implements Serializable {
    }
}

