/*
 * Decompiled with CFR 0.152.
 */
package com.hankcs.hanlp.dependency.perceptron.transition.parser;

import com.hankcs.hanlp.corpus.io.IOUtil;
import com.hankcs.hanlp.dependency.perceptron.accessories.CoNLLReader;
import com.hankcs.hanlp.dependency.perceptron.accessories.Edge;
import com.hankcs.hanlp.dependency.perceptron.accessories.Options;
import com.hankcs.hanlp.dependency.perceptron.accessories.Pair;
import com.hankcs.hanlp.dependency.perceptron.learning.AveragedPerceptron;
import com.hankcs.hanlp.dependency.perceptron.structures.IndexMaps;
import com.hankcs.hanlp.dependency.perceptron.structures.ParserModel;
import com.hankcs.hanlp.dependency.perceptron.structures.Sentence;
import com.hankcs.hanlp.dependency.perceptron.transition.configuration.BeamElement;
import com.hankcs.hanlp.dependency.perceptron.transition.configuration.Configuration;
import com.hankcs.hanlp.dependency.perceptron.transition.configuration.Instance;
import com.hankcs.hanlp.dependency.perceptron.transition.configuration.State;
import com.hankcs.hanlp.dependency.perceptron.transition.features.FeatureExtractor;
import com.hankcs.hanlp.dependency.perceptron.transition.parser.Action;
import com.hankcs.hanlp.dependency.perceptron.transition.parser.ArcEager;
import com.hankcs.hanlp.dependency.perceptron.transition.parser.BeamScorerThread;
import com.hankcs.hanlp.dependency.perceptron.transition.parser.ParseTaggedThread;
import com.hankcs.hanlp.dependency.perceptron.transition.parser.ParseThread;
import com.hankcs.hanlp.dependency.perceptron.transition.parser.PartialTreeBeamScorerThread;
import com.hankcs.hanlp.dependency.perceptron.transition.parser.TransitionBasedParser;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.TreeSet;
import java.util.concurrent.CompletionService;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorCompletionService;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

public class KBeamArcEagerParser
extends TransitionBasedParser {
    /** Fixed-size thread pool created in the constructor; it backs {@link #pool} and is stopped by {@link #shutDownLiveThreads()}. */
    ExecutorService executor;
    /** Completion service over {@link #executor}; beam-scoring tasks are submitted here and their candidate lists collected. */
    CompletionService<ArrayList<BeamElement>> pool;
    /** Parser options supplied at construction; the convenience {@code parse} overloads read root placement, lowercasing, beam width and thread count from it. */
    public Options options;

    /**
     * Creates a parser from the model stored at {@code modelPath}, sizing the
     * internal thread pool to the number of processors reported by the runtime.
     *
     * @param modelPath location of the serialized parser model
     * @throws IOException            propagated from model loading
     * @throws ClassNotFoundException propagated from model loading
     */
    public KBeamArcEagerParser(String modelPath) throws IOException, ClassNotFoundException {
        this(modelPath, Runtime.getRuntime().availableProcessors());
    }

    /**
     * Creates a parser from the model stored at {@code modelPath} with an
     * explicitly sized thread pool.
     *
     * @param modelPath    location of the serialized parser model
     * @param numOfThreads size of the internal fixed thread pool
     * @throws IOException            propagated from model loading
     * @throws ClassNotFoundException propagated from model loading
     */
    public KBeamArcEagerParser(String modelPath, int numOfThreads) throws IOException, ClassNotFoundException {
        this(new ParserModel(modelPath), numOfThreads);
    }

    /**
     * Creates a parser from an already loaded model. The classifier is built
     * from the model, and the dependency labels, index maps and options are
     * taken from it; the feature length is the length of the model's
     * shift-feature weight array.
     *
     * @param parserModel  loaded parser model
     * @param numOfThreads size of the internal fixed thread pool
     */
    public KBeamArcEagerParser(ParserModel parserModel, int numOfThreads) {
        this(new AveragedPerceptron(parserModel), parserModel.dependencyLabels, parserModel.shiftFeatureAveragedWeights.length, parserModel.maps, numOfThreads, parserModel.options);
    }

    public KBeamArcEagerParser(AveragedPerceptron classifier, ArrayList<Integer> dependencyRelations, int featureLength, IndexMaps maps, int numOfThreads, Options options) {
        super(classifier, dependencyRelations, featureLength, maps);
        this.executor = Executors.newFixedThreadPool(numOfThreads);
        this.pool = new ExecutorCompletionService<ArrayList<BeamElement>>(this.executor);
        this.options = options;
    }

    /**
     * Parses a tokenized, POS-tagged sentence using the root placement,
     * lowercasing and beam width from {@link #options}. Scoring always runs
     * with a thread count of 1, regardless of the configured thread count.
     *
     * @param words tokens of the sentence
     * @param tags  part-of-speech tags aligned with {@code words}
     * @return the best-scoring final configuration
     */
    public Configuration parse(String[] words, String[] tags) throws ExecutionException, InterruptedException {
        Sentence sentence = this.maps.makeSentence(words, tags, this.options.rootFirst, this.options.lowercase);
        return this.parse(sentence, this.options.rootFirst, this.options.beamWidth, 1);
    }

    /**
     * Parses a prepared sentence with the root placement, beam width and
     * thread count taken from {@link #options}.
     *
     * @param sentence sentence to parse
     * @return the best-scoring final configuration
     */
    public Configuration parse(Sentence sentence) throws ExecutionException, InterruptedException {
        Options settings = this.options;
        return this.parse(sentence, settings.rootFirst, settings.beamWidth, settings.numOfThreads);
    }

    /**
     * Parses a tokenized, POS-tagged sentence with explicit decoding settings.
     * <p>
     * The sentence is built with the same {@code rootFirst} value that is used
     * for the initial configuration, so both always agree on root placement.
     * Lowercasing still comes from {@link #options}.
     *
     * @param words        tokens of the sentence
     * @param tags         part-of-speech tags aligned with {@code words}
     * @param rootFirst    root placement used for both the sentence and the configuration
     * @param beamWidth    maximum number of configurations kept per step
     * @param numOfThreads 1 scores on the calling thread; any other value uses the worker pool
     * @return the best-scoring final configuration
     */
    public Configuration parse(String[] words, String[] tags, boolean rootFirst, int beamWidth, int numOfThreads) throws ExecutionException, InterruptedException {
        Sentence sentence = this.maps.makeSentence(words, tags, rootFirst, this.options.lowercase);
        return this.parse(sentence, rootFirst, beamWidth, numOfThreads);
    }

    /**
     * Beam-search decoding of {@code sentence}.
     * <p>
     * Starting from a single initial configuration, each step scores the
     * transitions of every configuration in the beam, keeps at most
     * {@code beamWidth} candidates and applies them to obtain the next beam.
     * The loop stops once {@code ArcEager.isTerminal} accepts the beam.
     *
     * @param sentence     sentence to parse
     * @param rootFirst    root placement used for the initial configuration
     * @param beamWidth    maximum number of configurations kept per step
     * @param numOfThreads 1 scores on the calling thread; any other value submits one task per beam entry to the worker pool
     * @return the configuration with the highest score in the final beam, or {@code null} if that beam is empty
     */
    public Configuration parse(Sentence sentence, boolean rootFirst, int beamWidth, int numOfThreads) throws ExecutionException, InterruptedException {
        Configuration seed = new Configuration(sentence, rootFirst);
        ArrayList<Configuration> beam = new ArrayList<Configuration>(beamWidth);
        beam.add(seed);

        final boolean singleThreaded = numOfThreads == 1;
        while (!ArcEager.isTerminal(beam)) {
            TreeSet<BeamElement> candidates = new TreeSet<BeamElement>();
            if (singleThreaded) {
                // Scored against an instance that carries no gold edges.
                Instance goldless = new Instance(sentence, new HashMap<Integer, Edge>());
                KBeamArcEagerParser.sortBeam(beam, candidates, false, goldless, beamWidth, rootFirst, this.featureLength, this.classifier, this.dependencyRelations);
            } else {
                int slot = 0;
                for (Configuration member : beam) {
                    this.pool.submit(new BeamScorerThread(true, this.classifier, member, this.dependencyRelations, this.featureLength, slot, rootFirst));
                    ++slot;
                }
                this.fetchBeamFromPool(beamWidth, beam, candidates);
            }
            beam = this.commitActionInBeam(beamWidth, beam, candidates);
        }

        // Pick the highest-scoring survivor.
        Configuration winner = null;
        float topScore = Float.NEGATIVE_INFINITY;
        Iterator<Configuration> survivors = beam.iterator();
        while (survivors.hasNext()) {
            Configuration contender = survivors.next();
            if (contender.getScore(true) > topScore) {
                topScore = contender.getScore(true);
                winner = contender;
            }
        }
        return winner;
    }

    /**
     * Builds the next beam from the retained candidates.
     * <p>
     * Candidates are visited in descending order of the set; for each one the
     * source configuration is cloned and the candidate's action is applied to
     * the clone. At most {@code beamWidth} successors are produced.
     *
     * @param beamWidth     maximum size of the returned beam
     * @param beam          current beam; candidates refer to its entries by index
     * @param beamPreserver candidates selected for this step
     * @return the successor configurations
     */
    private ArrayList<Configuration> commitActionInBeam(int beamWidth, ArrayList<Configuration> beam, TreeSet<BeamElement> beamPreserver) {
        ArrayList<Configuration> successors = new ArrayList<Configuration>(beamWidth);
        Iterator<BeamElement> bestFirst = beamPreserver.descendingIterator();
        while (bestFirst.hasNext() && successors.size() < beamWidth) {
            BeamElement chosen = bestFirst.next();
            Configuration successor = beam.get(chosen.index).clone();
            ArcEager.commitAction(chosen.action, chosen.label, chosen.score, this.dependencyRelations, successor);
            successors.add(successor);
        }
        return successors;
    }

    /**
     * Single-threaded scoring step for partial-tree parsing.
     * <p>
     * Candidates are first collected under the constraints of {@code instance}.
     * If that yields nothing, scoring is repeated through
     * {@code ParseThread.sortBeam} with no instance and the non-projective
     * flag off.
     */
    private void parsePartialWithOneThread(ArrayList<Configuration> beam, TreeSet<BeamElement> beamPreserver, Boolean isNonProjective, Instance instance, int beamWidth, boolean rootFirst) {
        KBeamArcEagerParser.sortBeam(beam, beamPreserver, isNonProjective, instance, beamWidth, rootFirst, this.featureLength, this.classifier, this.dependencyRelations);
        if (!beamPreserver.isEmpty()) {
            return;
        }
        ParseThread.sortBeam(beam, beamPreserver, false, null, beamWidth, rootFirst, this.featureLength, this.classifier, this.dependencyRelations);
    }

    /**
     * Scores every applicable transition of every configuration in
     * {@code beam} and keeps the best {@code beamWidth} candidates in
     * {@code beamPreserver}.
     * <p>
     * Action codes stored in the candidates: 0 = shift, 1 = reduce,
     * 2 = right-arc, 3 = left-arc. Code 4 is recorded with the unchanged score
     * when none of the four transitions is possible and {@code rootFirst} is set.
     * A transition is admitted when {@code isNonProjective} is true or when
     * {@code instance.actionCost} reports zero for it.
     *
     * @param beam                configurations to expand
     * @param beamPreserver       receives the candidates; trimmed to {@code beamWidth}
     * @param isNonProjective     when true, the cost check against {@code instance} is skipped
     * @param instance            instance consulted for action costs
     * @param beamWidth           maximum number of candidates retained
     * @param rootFirst           root placement flag of the parse
     * @param featureLength       number of features to extract per configuration
     * @param classifier          perceptron producing the transition scores
     * @param dependencyRelations dependency label ids tried for arc actions
     */
    private static void sortBeam(ArrayList<Configuration> beam, TreeSet<BeamElement> beamPreserver, Boolean isNonProjective, Instance instance, int beamWidth, boolean rootFirst, int featureLength, AveragedPerceptron classifier, Collection<Integer> dependencyRelations) {
        int beamIndex = 0;
        for (Configuration configuration : beam) {
            State state = configuration.state;
            float baseScore = configuration.score;
            boolean shiftAllowed = ArcEager.canDo(Action.Shift, state);
            boolean reduceAllowed = ArcEager.canDo(Action.Reduce, state);
            boolean rightArcAllowed = ArcEager.canDo(Action.RightArc, state);
            boolean leftArcAllowed = ArcEager.canDo(Action.LeftArc, state);
            Object[] features = FeatureExtractor.extractAllParseFeatures(configuration, featureLength);

            boolean anyAllowed = shiftAllowed || reduceAllowed || rightArcAllowed || leftArcAllowed;
            if (rootFirst && !anyAllowed) {
                KBeamArcEagerParser.offerToBeam(beamPreserver, new BeamElement(baseScore, beamIndex, 4, -1), beamWidth);
            }

            if (shiftAllowed && (isNonProjective.booleanValue() || instance.actionCost(Action.Shift, -1, state) == 0)) {
                float shiftScore = classifier.shiftScore(features, true);
                KBeamArcEagerParser.offerToBeam(beamPreserver, new BeamElement(shiftScore + baseScore, beamIndex, 0, -1), beamWidth);
            }

            if (reduceAllowed && (isNonProjective.booleanValue() || instance.actionCost(Action.Reduce, -1, state) == 0)) {
                float reduceScore = classifier.reduceScore(features, true);
                KBeamArcEagerParser.offerToBeam(beamPreserver, new BeamElement(reduceScore + baseScore, beamIndex, 1, -1), beamWidth);
            }

            if (rightArcAllowed) {
                float[] rightArcScores = classifier.rightArcScores(features, true);
                for (int relation : dependencyRelations) {
                    if (isNonProjective.booleanValue() || instance.actionCost(Action.RightArc, relation, state) == 0) {
                        float arcScore = rightArcScores[relation];
                        KBeamArcEagerParser.offerToBeam(beamPreserver, new BeamElement(arcScore + baseScore, beamIndex, 2, relation), beamWidth);
                    }
                }
            }

            if (leftArcAllowed) {
                float[] leftArcScores = classifier.leftArcScores(features, true);
                for (int relation : dependencyRelations) {
                    if (isNonProjective.booleanValue() || instance.actionCost(Action.LeftArc, relation, state) == 0) {
                        float arcScore = leftArcScores[relation];
                        KBeamArcEagerParser.offerToBeam(beamPreserver, new BeamElement(arcScore + baseScore, beamIndex, 3, relation), beamWidth);
                    }
                }
            }
            ++beamIndex;
        }
    }

    /**
     * Adds {@code candidate} to {@code beamPreserver} and, if the set then
     * exceeds {@code beamWidth}, drops its first (lowest-ordered) element.
     */
    private static void offerToBeam(TreeSet<BeamElement> beamPreserver, BeamElement candidate, int beamWidth) {
        beamPreserver.add(candidate);
        if (beamPreserver.size() > beamWidth) {
            beamPreserver.pollFirst();
        }
    }

    /**
     * Beam-search decoding constrained by the (possibly partial) tree carried
     * by {@code instance}.
     * <p>
     * The structure mirrors the unconstrained parse: score, trim to
     * {@code beamWidth}, apply, repeat until {@code ArcEager.isTerminal}
     * accepts the beam. When the instance reports itself non-projective, the
     * single-threaded scorer skips the action-cost check.
     *
     * @param instance     instance whose tree constrains the admissible actions
     * @param sentence     sentence to parse
     * @param rootFirst    root placement used for the initial configuration
     * @param beamWidth    maximum number of configurations kept per step
     * @param numOfThreads 1 scores on the calling thread; any other value uses the worker pool
     * @return the configuration with the highest score in the final beam, or {@code null} if that beam is empty
     */
    public Configuration parsePartial(Instance instance, Sentence sentence, boolean rootFirst, int beamWidth, int numOfThreads) throws ExecutionException, InterruptedException {
        Configuration seed = new Configuration(sentence, rootFirst);
        boolean isNonProjective = instance.isNonprojective();
        ArrayList<Configuration> beam = new ArrayList<Configuration>(beamWidth);
        beam.add(seed);

        final boolean singleThreaded = numOfThreads == 1;
        while (!ArcEager.isTerminal(beam)) {
            TreeSet<BeamElement> candidates = new TreeSet<BeamElement>();
            if (singleThreaded) {
                this.parsePartialWithOneThread(beam, candidates, isNonProjective, instance, beamWidth, rootFirst);
            } else {
                int slot = 0;
                for (Configuration member : beam) {
                    this.pool.submit(new PartialTreeBeamScorerThread(true, this.classifier, instance, member, this.dependencyRelations, this.featureLength, slot));
                    ++slot;
                }
                this.fetchBeamFromPool(beamWidth, beam, candidates);
            }
            beam = this.commitActionInBeam(beamWidth, beam, candidates);
        }

        // Pick the highest-scoring survivor.
        Configuration winner = null;
        float topScore = Float.NEGATIVE_INFINITY;
        Iterator<Configuration> survivors = beam.iterator();
        while (survivors.hasNext()) {
            Configuration contender = survivors.next();
            if (contender.getScore(true) > topScore) {
                topScore = contender.getScore(true);
                winner = contender;
            }
        }
        return winner;
    }

    /**
     * Collects the results of the scoring tasks submitted for the current
     * beam (one per beam entry) and merges them into {@code beamPreserver},
     * trimming the set to {@code beamWidth} after every insertion.
     *
     * @param beamWidth     maximum number of candidates retained
     * @param beam          current beam; its size is the number of results awaited
     * @param beamPreserver receives the merged candidates
     */
    private void fetchBeamFromPool(int beamWidth, ArrayList<Configuration> beam, TreeSet<BeamElement> beamPreserver) throws InterruptedException, ExecutionException {
        int outstanding = beam.size();
        while (outstanding > 0) {
            ArrayList<BeamElement> scored = this.pool.take().get();
            for (BeamElement candidate : scored) {
                beamPreserver.add(candidate);
                if (beamPreserver.size() > beamWidth) {
                    beamPreserver.pollFirst();
                }
            }
            --outstanding;
        }
    }

    /**
     * Parses a CoNLL file, choosing the serial or the parallel implementation
     * from {@code numThreads}.
     * <p>
     * Note that {@code labeled} is only forwarded to the serial variant; the
     * parallel variant does not take that parameter.
     *
     * @param inputFile  CoNLL input file
     * @param outputFile destination of the parsed CoNLL output
     * @param rootFirst  root placement flag
     * @param beamWidth  beam width used for decoding
     * @param labeled    forwarded to the serial reader
     * @param lowerCased whether words are lowercased when read
     * @param numThreads 1 selects the serial variant, anything else the parallel one
     * @param partial    whether to run partial-tree parsing
     * @param scorePath  file receiving per-sentence scores; blank disables score output
     */
    public void parseConllFile(String inputFile, String outputFile, boolean rootFirst, int beamWidth, boolean labeled, boolean lowerCased, int numThreads, boolean partial, String scorePath) throws IOException, ExecutionException, InterruptedException {
        if (numThreads != 1) {
            this.parseConllFileParallel(inputFile, outputFile, rootFirst, beamWidth, lowerCased, numThreads, partial, scorePath);
            return;
        }
        this.parseConllFileNoParallel(inputFile, outputFile, rootFirst, beamWidth, labeled, lowerCased, numThreads, partial, scorePath);
    }

    /**
     * Parses a CoNLL file sentence by sentence on the calling thread.
     * <p>
     * Works in two passes: the predicted head/label pairs are first written to
     * {@code outputFile + ".tmp"}, then merged into columns 7 and 8 of the
     * original input lines to produce {@code outputFile}. If {@code scorePath}
     * is non-blank, one length-normalized score per sentence is written there.
     * <p>
     * All readers and writers are closed even when parsing fails, and the
     * temporary file is deleted only after every handle on it is closed.
     *
     * @param inputFile    CoNLL input file
     * @param outputFile   destination of the parsed CoNLL output
     * @param rootFirst    root placement flag
     * @param beamWidth    beam width used for decoding
     * @param labeled      passed to the CoNLL reader
     * @param lowerCased   whether words are lowercased when read
     * @param numOfThreads thread count handed to the per-sentence parse call
     * @param partial      whether to run partial-tree parsing
     * @param scorePath    file receiving per-sentence scores; {@code null} or blank disables score output
     * @throws IOException if a file cannot be read or written, or if the input
     *                     file ends before all predicted tokens were merged
     */
    public void parseConllFileNoParallel(String inputFile, String outputFile, boolean rootFirst, int beamWidth, boolean labeled, boolean lowerCased, int numOfThreads, boolean partial, String scorePath) throws IOException, ExecutionException, InterruptedException {
        CoNLLReader reader = new CoNLLReader(inputFile);
        boolean addScore = scorePath != null && scorePath.trim().length() > 0;
        ArrayList<Float> scoreList = new ArrayList<Float>();
        String tmpFile = outputFile + ".tmp";
        try {
            // Pass 1: parse every sentence and dump "head<TAB>label" rows to the temp file.
            BufferedWriter writer = new BufferedWriter(new FileWriter(tmpFile));
            try {
                int dataCount = 0;
                ArrayList<Instance> data = reader.readData(15000, true, labeled, rootFirst, lowerCased, this.maps);
                while (!data.isEmpty()) {
                    for (Instance instance : data) {
                        if (++dataCount % 100 == 0) {
                            System.err.print(dataCount + " ... ");
                        }
                        Configuration bestParse = partial ? this.parsePartial(instance, instance.getSentence(), rootFirst, beamWidth, numOfThreads) : this.parse(instance.getSentence(), rootFirst, beamWidth, numOfThreads);
                        if (addScore) {
                            scoreList.add(Float.valueOf(bestParse.score / (float)bestParse.sentence.size()));
                        }
                        this.writeParsedSentence(writer, rootFirst, bestParse, instance.getSentence().getWords());
                    }
                    data = reader.readData(15000, true, labeled, rootFirst, lowerCased, this.maps);
                }
            } finally {
                writer.close();
            }

            // Pass 2: splice the predictions into the original lines.
            BufferedReader gReader = new BufferedReader(new FileReader(inputFile));
            try {
                BufferedReader pReader = new BufferedReader(new FileReader(tmpFile));
                try {
                    BufferedWriter pwriter = new BufferedWriter(new FileWriter(outputFile));
                    try {
                        String line;
                        while ((line = pReader.readLine()) != null) {
                            String gLine = gReader.readLine();
                            if (line.trim().length() == 0) {
                                pwriter.write("\n");
                                continue;
                            }
                            // Skip blank separator lines in the input until a token line is found.
                            while (gLine != null && gLine.trim().length() == 0) {
                                gLine = gReader.readLine();
                            }
                            if (gLine == null) {
                                throw new IOException("Input file " + inputFile + " ended before all parsed tokens were merged");
                            }
                            String[] ps = line.split("\t");
                            String[] gs = gLine.split("\t");
                            gs[6] = ps[0];
                            gs[7] = ps[1];
                            StringBuilder output = new StringBuilder();
                            for (int i = 0; i < gs.length; ++i) {
                                output.append(gs[i]).append("\t");
                            }
                            pwriter.write(output.toString().trim() + "\n");
                        }
                    } finally {
                        pwriter.close();
                    }
                } finally {
                    pReader.close();
                }
            } finally {
                gReader.close();
            }

            if (addScore) {
                BufferedWriter scoreWriter = new BufferedWriter(new FileWriter(scorePath));
                try {
                    for (Float sentenceScore : scoreList) {
                        scoreWriter.write(sentenceScore + "\n");
                    }
                } finally {
                    scoreWriter.close();
                }
            }
        } finally {
            // Every handle on the temp file is closed by now, so the delete can succeed on all platforms.
            IOUtil.deleteFile(tmpFile);
        }
    }

    /**
     * Writes one parsed sentence as {@code head<TAB>label} rows, one per
     * token, followed by an empty line.
     * <p>
     * Token positions run from 1 to {@code words.length}. A head equal to the
     * root index is printed as 0 and labelled with the maps' root string;
     * otherwise the label is looked up in {@code maps.idWord} using the value
     * of {@code getDependent}. When the root is not placed first, the row for
     * the root position itself is omitted.
     *
     * @param writer    destination of the rows
     * @param rootFirst root placement flag of the parse
     * @param bestParse configuration holding the predicted tree
     * @param words     word ids of the sentence; only the length is used
     */
    private void writeParsedSentence(BufferedWriter writer, boolean rootFirst, Configuration bestParse, int[] words) throws IOException {
        StringBuilder rows = new StringBuilder();
        for (int position = 1; position <= words.length; ++position) {
            int head = bestParse.state.getHead(position);
            int dep = bestParse.state.getDependent(position);
            if (rootFirst || position != bestParse.state.rootIndex) {
                if (head == bestParse.state.rootIndex) {
                    head = 0;
                }
                String label;
                if (head == 0) {
                    label = this.maps.rootString;
                } else {
                    label = this.maps.idWord[dep];
                }
                rows.append(head).append("\t").append(label).append("\n");
            }
        }
        rows.append("\n");
        writer.write(rows.toString());
    }

    /**
     * Parses a file containing one tagged sentence per line, using a
     * dedicated thread pool, and writes the results in input order.
     * <p>
     * Lines are submitted as tasks and drained in batches of 1000 so that the
     * output order matches the input order within each batch. The thread pool,
     * the reader and the writer are released even when parsing fails.
     *
     * @param inputFile    file with one tagged sentence per line
     * @param outputFile   destination of the parsed output
     * @param rootFirst    root placement flag
     * @param beamWidth    beam width used for decoding
     * @param lowerCased   whether words are lowercased
     * @param separator    separator between word and tag, passed to the parse task
     * @param numOfThreads size of the thread pool used for this call
     * @throws Exception if reading, writing or any parse task fails
     */
    public void parseTaggedFile(String inputFile, String outputFile, boolean rootFirst, int beamWidth, boolean lowerCased, String separator, int numOfThreads) throws Exception {
        BufferedReader reader = new BufferedReader(new FileReader(inputFile));
        try {
            BufferedWriter writer = new BufferedWriter(new FileWriter(outputFile));
            try {
                long start = System.currentTimeMillis();
                ExecutorService executor = Executors.newFixedThreadPool(numOfThreads);
                try {
                    ExecutorCompletionService<Pair<String, Integer>> pool = new ExecutorCompletionService<Pair<String, Integer>>(executor);
                    int count = 0;
                    int lineNum = 0;
                    String line;
                    while ((line = reader.readLine()) != null) {
                        pool.submit(new ParseTaggedThread(lineNum++, line, separator, rootFirst, lowerCased, this.maps, beamWidth, this));
                        if (lineNum % 1000 == 0) {
                            count = KBeamArcEagerParser.writeTaggedBatch(pool, writer, lineNum, count);
                            lineNum = 0;
                        }
                    }
                    if (lineNum > 0) {
                        KBeamArcEagerParser.writeTaggedBatch(pool, writer, lineNum, count);
                    }
                } finally {
                    // Without this the pool's worker threads would outlive the call.
                    executor.shutdownNow();
                }
                long end = System.currentTimeMillis();
                System.out.println("\n" + (end - start) + " ms");
            } finally {
                writer.close();
            }
        } finally {
            reader.close();
        }
        System.out.println("done!");
    }

    /**
     * Waits for {@code batchSize} finished parse tasks, orders their outputs
     * by the index each task reports, and writes the non-empty ones.
     *
     * @return the running progress counter after this batch
     */
    private static int writeTaggedBatch(CompletionService<Pair<String, Integer>> pool, BufferedWriter writer, int batchSize, int count) throws IOException, InterruptedException, ExecutionException {
        String[] outs = new String[batchSize];
        for (int i = 0; i < batchSize; ++i) {
            if (++count % 100 == 0) {
                System.err.print(count + "...");
            }
            Pair<String, Integer> result = pool.take().get();
            outs[result.second.intValue()] = result.first;
        }
        for (String out : outs) {
            if (out != null && out.length() > 0) {
                writer.write(out);
            }
        }
        return count;
    }

    /**
     * Parses a CoNLL file with one parse task per sentence on a dedicated
     * thread pool.
     * <p>
     * Works in two passes: the predicted head/label pairs are first written to
     * {@code outputFile + ".tmp"} in input order, then merged into columns 7
     * and 8 of the original input lines to produce {@code outputFile}. If
     * {@code scorePath} is non-blank, one length-normalized score per sentence
     * is written there.
     * <p>
     * The thread pool is shut down and all readers and writers are closed even
     * when parsing fails; the temporary file is deleted only after every
     * handle on it is closed.
     *
     * @param inputFile  CoNLL input file
     * @param outputFile destination of the parsed CoNLL output
     * @param rootFirst  root placement flag
     * @param beamWidth  beam width used for decoding
     * @param lowerCased whether words are lowercased when read
     * @param numThreads size of the thread pool used for this call
     * @param partial    whether to run partial-tree parsing
     * @param scorePath  file receiving per-sentence scores; {@code null} or blank disables score output
     * @throws IOException if a file cannot be read or written, or if the input
     *                     file ends before all predicted tokens were merged
     */
    public void parseConllFileParallel(String inputFile, String outputFile, boolean rootFirst, int beamWidth, boolean lowerCased, int numThreads, boolean partial, String scorePath) throws IOException, InterruptedException, ExecutionException {
        CoNLLReader reader = new CoNLLReader(inputFile);
        boolean addScore = scorePath != null && scorePath.trim().length() > 0;
        ArrayList<Float> scoreList = new ArrayList<Float>();
        String tmpFile = outputFile + ".tmp";
        try {
            // Pass 1: parse in parallel, then dump "head<TAB>label" rows in input order.
            ExecutorService executor = Executors.newFixedThreadPool(numThreads);
            try {
                ExecutorCompletionService<Pair<Configuration, Integer>> pool = new ExecutorCompletionService<Pair<Configuration, Integer>>(executor);
                BufferedWriter writer = new BufferedWriter(new FileWriter(tmpFile));
                try {
                    ArrayList<Instance> data = reader.readData(15000, true, true, rootFirst, lowerCased, this.maps);
                    while (!data.isEmpty()) {
                        Configuration[] confs = new Configuration[data.size()];
                        int index = 0;
                        for (Instance instance : data) {
                            ParseThread thread = new ParseThread(index, this.classifier, this.dependencyRelations, this.featureLength, instance.getSentence(), rootFirst, beamWidth, instance, partial);
                            pool.submit(thread);
                            ++index;
                        }
                        // Results arrive in completion order; each carries its own slot.
                        for (int i = 0; i < confs.length; ++i) {
                            Pair<Configuration, Integer> parsed = pool.take().get();
                            confs[parsed.second.intValue()] = parsed.first;
                        }
                        for (int j = 0; j < confs.length; ++j) {
                            Configuration bestParse = confs[j];
                            if (addScore) {
                                scoreList.add(Float.valueOf(bestParse.score / (float)bestParse.sentence.size()));
                            }
                            this.writeParsedSentence(writer, rootFirst, bestParse, data.get(j).getSentence().getWords());
                        }
                        data = reader.readData(15000, true, true, rootFirst, lowerCased, this.maps);
                    }
                } finally {
                    writer.close();
                }
            } finally {
                // Without this the pool's worker threads would outlive the call.
                executor.shutdownNow();
            }

            // Pass 2: splice the predictions into the original lines.
            BufferedReader gReader = new BufferedReader(new FileReader(inputFile));
            try {
                BufferedReader pReader = new BufferedReader(new FileReader(tmpFile));
                try {
                    BufferedWriter pwriter = new BufferedWriter(new FileWriter(outputFile));
                    try {
                        String line;
                        while ((line = pReader.readLine()) != null) {
                            String gLine = gReader.readLine();
                            if (line.trim().length() == 0) {
                                pwriter.write("\n");
                                continue;
                            }
                            // Skip blank separator lines in the input until a token line is found.
                            while (gLine != null && gLine.trim().length() == 0) {
                                gLine = gReader.readLine();
                            }
                            if (gLine == null) {
                                throw new IOException("Input file " + inputFile + " ended before all parsed tokens were merged");
                            }
                            String[] ps = line.split("\t");
                            String[] gs = gLine.split("\t");
                            gs[6] = ps[0];
                            gs[7] = ps[1];
                            StringBuilder output = new StringBuilder();
                            for (int i = 0; i < gs.length; ++i) {
                                output.append(gs[i]).append("\t");
                            }
                            pwriter.write(output.toString().trim() + "\n");
                        }
                    } finally {
                        pwriter.close();
                    }
                } finally {
                    pReader.close();
                }
            } finally {
                gReader.close();
            }

            if (addScore) {
                BufferedWriter scoreWriter = new BufferedWriter(new FileWriter(scorePath));
                try {
                    for (Float sentenceScore : scoreList) {
                        scoreWriter.write(sentenceScore + "\n");
                    }
                } finally {
                    scoreWriter.close();
                }
            }
        } finally {
            // Every handle on the temp file is closed by now, so the delete can succeed on all platforms.
            IOUtil.deleteFile(tmpFile);
        }
    }

    /**
     * Stops the parser's worker pool and waits until it has terminated.
     * <p>
     * Running tasks are interrupted via {@code shutdownNow()}; the method then
     * blocks in one-second waits instead of spinning, re-issuing the shutdown
     * request after each timeout. If the calling thread is interrupted while
     * waiting, its interrupt flag is restored and the method returns without
     * waiting further.
     */
    public void shutDownLiveThreads() {
        this.executor.shutdownNow();
        try {
            while (!this.executor.awaitTermination(1L, java.util.concurrent.TimeUnit.SECONDS)) {
                // A task may have swallowed the first interrupt; ask again.
                this.executor.shutdownNow();
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }
}

