/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.mxnet.zoo.nlp.qa;

import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import com.google.gson.annotations.SerializedName;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.Reader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * Vocabulary holder and pre-processing helpers for a BERT question-answering model.
 *
 * <p>An instance is created by deserializing a JSON vocabulary that contains a
 * {@code token_to_idx} object and an {@code idx_to_token} array (see {@link #parse(InputStream)}).
 * The static helpers tokenize text, pad lists and assemble the token / token-type sequences.
 */
public class BertDataParser {
    private static final Gson GSON = new GsonBuilder().create();

    /**
     * One whitespace-delimited word: group 1 is the word, group 2 is an optional single trailing
     * punctuation mark ({@code . , ? !}) that is split off as its own token.
     */
    private static final Pattern PATTERN = Pattern.compile("(\\S+?)([.,?!])?(\\s+|$)");

    /** Maps a token to its vocabulary index; populated by Gson from {@code token_to_idx}. */
    @SerializedName(value="token_to_idx")
    private Map<String, Integer> token2idx;

    /** Maps a vocabulary index to its token; populated by Gson from {@code idx_to_token}. */
    @SerializedName(value="idx_to_token")
    private List<String> idx2token;

    /**
     * Deserializes a vocabulary from UTF-8 encoded JSON.
     *
     * <p>The stream is wrapped in a reader that is closed before this method returns, which also
     * closes {@code is}.
     *
     * @param is the stream to read the JSON vocabulary from
     * @return the object produced by Gson for the JSON content
     * @throws IllegalStateException if an {@link IOException} is raised while closing the reader
     */
    public static BertDataParser parse(InputStream is) {
        try (Reader reader = new InputStreamReader(is, StandardCharsets.UTF_8)) {
            return GSON.fromJson(reader, BertDataParser.class);
        } catch (IOException e) {
            throw new IllegalStateException(e);
        }
    }

    /**
     * Splits text on whitespace and separates one trailing punctuation mark
     * ({@code . , ? !}) from each word, e.g. {@code "hello world!"} becomes
     * {@code [hello, world, !]}.
     *
     * @param input the text to tokenize
     * @return a new mutable list of tokens in order of appearance
     */
    public static List<String> tokenizer(String input) {
        List<String> tokens = new ArrayList<>();
        Matcher matcher = PATTERN.matcher(input);
        while (matcher.find()) {
            tokens.add(matcher.group(1));
            String punctuation = matcher.group(2);
            if (punctuation != null) {
                tokens.add(punctuation);
            }
        }
        return tokens;
    }

    /**
     * Right-pads a list with {@code padItem} until it holds {@code num} elements.
     *
     * <p>When {@code tokens} already has at least {@code num} elements it is returned as is
     * (the same instance, never truncated); otherwise a new list is returned and {@code tokens}
     * is left untouched.
     *
     * @param tokens the list to pad
     * @param padItem the element appended to fill the list
     * @param num the minimum size of the result
     * @param <E> the element type
     * @return {@code tokens} itself, or a new padded copy of size {@code num}
     */
    public static <E> List<E> pad(List<E> tokens, E padItem, int num) {
        int missing = num - tokens.size();
        if (missing <= 0) {
            return tokens;
        }
        List<E> padded = new ArrayList<>(num);
        padded.addAll(tokens);
        while (missing-- > 0) {
            padded.add(padItem);
        }
        return padded;
    }

    /**
     * Builds the token-type sequence: {@code question.size() + 2} entries of {@code 0.0},
     * followed by {@code answer.size()} entries of {@code 1.0}, right-padded with {@code 0.0}
     * up to {@code seqLength}. The result is not truncated when it is already longer than
     * {@code seqLength}.
     *
     * @param question the question tokens (only the size is used)
     * @param answer the answer tokens (only the size is used)
     * @param seqLength the minimum length of the result
     * @return a new list of token types
     */
    public static List<Float> getTokenTypes(List<String> question, List<String> answer, int seqLength) {
        Float questionType = Float.valueOf(0.0f);
        Float answerType = Float.valueOf(1.0f);
        // question.size() + 2 is always positive, so pad() hands back a fresh mutable list here.
        List<Float> tokenTypes = pad(new ArrayList<Float>(), questionType, question.size() + 2);
        tokenTypes.addAll(pad(new ArrayList<Float>(), answerType, answer.size()));
        return pad(tokenTypes, questionType, seqLength);
    }

    /**
     * Assembles the token sequence
     * {@code [CLS] question [SEP] answer [SEP] [SEP]} and right-pads it with {@code [PAD]} up to
     * {@code seqLength}. Neither argument is modified.
     *
     * @param question the question tokens
     * @param answer the answer tokens
     * @param seqLength the minimum length of the result
     * @return a new list holding the assembled tokens
     */
    public static List<String> formTokens(List<String> question, List<String> answer, int seqLength) {
        List<String> tokens = new ArrayList<>(question.size() + answer.size() + 4);
        tokens.add("[CLS]");
        tokens.addAll(question);
        tokens.add("[SEP]");
        tokens.addAll(answer);
        // NOTE(review): two consecutive [SEP] tokens follow the answer. This keeps the emitted
        // sequence identical to what this method has always produced; confirm whether the
        // model really expects the duplicate before removing it.
        tokens.add("[SEP]");
        tokens.add("[SEP]");
        return pad(tokens, "[PAD]", seqLength);
    }

    /**
     * Converts tokens to vocabulary indexes; a token missing from the vocabulary is mapped to
     * the index stored for {@code [UNK]}.
     *
     * @param tokens the tokens to look up
     * @return a new list with one index per token, in the same order
     */
    public List<Integer> token2idx(List<String> tokens) {
        List<Integer> indexes = new ArrayList<>();
        for (String token : tokens) {
            String key = this.token2idx.containsKey(token) ? token : "[UNK]";
            indexes.add(this.token2idx.get(key));
        }
        return indexes;
    }

    /**
     * Converts vocabulary indexes back to their tokens.
     *
     * @param indexes the indexes to look up; each must be a valid position in the vocabulary
     * @return a new list with one token per index, in the same order
     */
    public List<String> idx2token(List<Integer> indexes) {
        List<String> tokens = new ArrayList<>();
        for (int index : indexes) {
            tokens.add(this.idx2token.get(index));
        }
        return tokens;
    }
}

