/*
 * Decompiled with CFR 0.152.
 */
package top.aoyudi.rag.impl;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Scanner;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import top.aoyudi.rag.EmbeddingGenerator;

/**
 * Bag-of-words {@link EmbeddingGenerator} that maps text onto a fixed-size,
 * L2-normalized term-frequency vector.
 *
 * <p>Each distinct word is assigned a slot in the output vector the first time it
 * is seen; the assignment is kept in an instance-level vocabulary that grows with
 * every call to {@link #generate(String)}. Once more than
 * {@code EMBEDDING_DIMENSION} words are known, slots are reused (the slot is the
 * vocabulary size modulo the dimension), so several words may share one slot.
 *
 * <p>Because the vocabulary is a plain, unsynchronized {@link HashMap} that is
 * mutated on every call, instances must not be shared between threads without
 * external synchronization, and the vector produced for a text depends on which
 * texts the same instance has processed before.
 */
public class DocumentEmbeddingGenerator implements EmbeddingGenerator {
    /** Length of every vector returned by {@link #generate(String)}. */
    private static final int EMBEDDING_DIMENSION = 128;
    /** Matches runs of ASCII letters and digits delimited by word boundaries. */
    private static final Pattern WORD_PATTERN = Pattern.compile("\\b[a-zA-Z0-9]+\\b");
    /** Maximum number of (most frequent) words of one text that may enter the vocabulary. */
    private static final int TOP_K_WORDS = 1000;
    /** Word -> slot index in the embedding vector; grows across calls. */
    private final Map<String, Integer> vocabulary = new HashMap<>();

    /**
     * Builds the embedding for {@code content} and registers its most frequent
     * words in the vocabulary.
     *
     * @param content text to embed; {@code null} is treated like an empty string
     * @return a new array of length {@code EMBEDDING_DIMENSION}; all zeros when the
     *         text contains no usable word, otherwise scaled to unit Euclidean length
     */
    @Override
    public float[] generate(String content) {
        List<String> words = preprocessText(content);
        Map<String, Integer> wordCounts = countWords(words);
        updateVocabulary(wordCounts);
        return generateEmbedding(wordCounts, words.size());
    }

    /**
     * Lower-cases the text and extracts every word of at least two characters.
     *
     * @param content raw text, may be {@code null}
     * @return the extracted words in order of appearance (possibly empty, never {@code null})
     */
    private List<String> preprocessText(String content) {
        List<String> words = new ArrayList<>();
        if (content == null) {
            return words;
        }
        // Locale.ROOT keeps the mapping independent of the JVM default locale
        // (e.g. Turkish would turn 'I' into a dotless 'ı' that WORD_PATTERN rejects).
        // The pattern can never span a line separator, so the text does not need
        // to be split into lines first.
        Matcher matcher = WORD_PATTERN.matcher(content.toLowerCase(Locale.ROOT));
        while (matcher.find()) {
            String word = matcher.group();
            if (word.length() > 1) {
                words.add(word);
            }
        }
        return words;
    }

    /**
     * Counts how often each word occurs.
     *
     * @param words words of one text
     * @return word -> number of occurrences
     */
    private Map<String, Integer> countWords(List<String> words) {
        Map<String, Integer> wordCounts = new HashMap<>();
        for (String word : words) {
            wordCounts.merge(word, 1, Integer::sum);
        }
        return wordCounts;
    }

    /**
     * Adds the {@code TOP_K_WORDS} most frequent, not yet known words of a text to
     * the vocabulary. Words that are already known keep their slot.
     *
     * @param wordCounts word -> number of occurrences in the current text
     */
    private void updateVocabulary(Map<String, Integer> wordCounts) {
        List<Map.Entry<String, Integer>> mostFrequent = wordCounts.entrySet().stream()
                // The explicit type witness is required: without it the chained
                // reversed() call prevents inference of the entry type.
                .sorted(Map.Entry.<String, Integer>comparingByValue().reversed())
                .limit(TOP_K_WORDS)
                .collect(Collectors.toList());
        for (Map.Entry<String, Integer> entry : mostFrequent) {
            String word = entry.getKey();
            if (!this.vocabulary.containsKey(word)) {
                // Slots wrap around once the vocabulary exceeds the vector length.
                this.vocabulary.put(word, this.vocabulary.size() % EMBEDDING_DIMENSION);
            }
        }
    }

    /**
     * Accumulates the relative frequency of every known word into its slot and
     * normalizes the result.
     *
     * @param wordCounts word -> number of occurrences in the current text
     * @param totalWords total number of words in the current text
     * @return the embedding vector; all zeros when {@code totalWords} is 0
     */
    private float[] generateEmbedding(Map<String, Integer> wordCounts, int totalWords) {
        float[] embedding = new float[EMBEDDING_DIMENSION];
        if (totalWords == 0) {
            return embedding;
        }
        for (Map.Entry<String, Integer> entry : wordCounts.entrySet()) {
            Integer index = this.vocabulary.get(entry.getKey());
            if (index == null) {
                // Word did not make it into the vocabulary (outside the top K).
                continue;
            }
            int count = entry.getValue();
            embedding[index] += (float) count / (float) totalWords;
        }
        normalizeVector(embedding);
        return embedding;
    }

    /**
     * Scales {@code vector} in place to unit Euclidean length; a zero vector is
     * left untouched.
     *
     * @param vector vector to normalize
     */
    private void normalizeVector(float[] vector) {
        float norm = 0.0f;
        for (float value : vector) {
            norm += value * value;
        }
        norm = (float) Math.sqrt(norm);
        if (norm > 0.0f) {
            for (int i = 0; i < vector.length; i++) {
                vector[i] = vector[i] / norm;
            }
        }
    }
}

