/*
 * Decompiled with CFR 0.152.
 */
package org.anchoranalysis.plugin.onnx.bean.object.segment.decode.instance.text;

import ai.onnxruntime.OnnxTensor;
import java.nio.FloatBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.anchoranalysis.bean.annotation.BeanField;
import org.anchoranalysis.core.exception.OperationFailedException;
import org.anchoranalysis.core.time.ExecutionTimeRecorder;
import org.anchoranalysis.image.inference.ImageInferenceContext;
import org.anchoranalysis.image.inference.bean.segment.instance.DecodeInstanceSegmentation;
import org.anchoranalysis.image.inference.segment.DualScale;
import org.anchoranalysis.image.inference.segment.LabelledWithConfidence;
import org.anchoranalysis.image.inference.segment.MultiScaleObject;
import org.anchoranalysis.image.voxel.object.ObjectMask;
import org.anchoranalysis.mpp.mark.Mark;
import org.anchoranalysis.mpp.mark.MarkToObjectConverter;
import org.anchoranalysis.mpp.mark.points.RotatableBoundingBox;
import org.anchoranalysis.mpp.mark.points.RotatableBoundingBoxFactory;
import org.anchoranalysis.spatial.point.Point2i;
import org.anchoranalysis.spatial.scale.ScaleFactorInt;

public class DecodeEAST
extends DecodeInstanceSegmentation<OnnxTensor> {
    private static final String OUTPUT_SCORES = "feature_fusion/Conv_7/Sigmoid:0";
    private static final String OUTPUT_GEOMETRY = "feature_fusion/concat_3:0";
    private static final String CLASS_LABEL = "text";
    private static final ScaleFactorInt SCALE_BY_4 = new ScaleFactorInt(4, 4);
    private static final int VECTOR_SIZE = 5;
    @BeanField
    private double minConfidence = 0.5;

    public List<LabelledWithConfidence<MultiScaleObject>> decode(List<OnnxTensor> inferenceOutput, ImageInferenceContext context) throws OperationFailedException {
        FloatBuffer scores = inferenceOutput.get(0).getFloatBuffer();
        List<Integer> indices = this.indicesAboveThreshold(scores);
        return DecodeEAST.extractObjects(inferenceOutput.get(1), scores, indices, DecodeEAST.dualScaleConverters(context), context.getExecutionTimeRecorder());
    }

    public List<String> expectedOutputs() {
        return Arrays.asList(OUTPUT_SCORES, OUTPUT_GEOMETRY);
    }

    private List<Integer> indicesAboveThreshold(FloatBuffer scores) {
        scores.rewind();
        ArrayList<Integer> indices = new ArrayList<Integer>(scores.capacity());
        for (int i = 0; i < scores.capacity(); ++i) {
            if (!((double)scores.get() >= this.minConfidence)) continue;
            indices.add(i);
        }
        return indices;
    }

    private static List<LabelledWithConfidence<MultiScaleObject>> extractObjects(OnnxTensor geometryTensor, FloatBuffer scores, List<Integer> indices, DualScale<MarkToObjectConverter> converter, ExecutionTimeRecorder executionTimeRecorder) {
        ArrayList<LabelledWithConfidence<MultiScaleObject>> out = new ArrayList<LabelledWithConfidence<MultiScaleObject>>(indices.size());
        int height = (int)geometryTensor.getInfo().getShape()[2];
        FloatBuffer geometryBuffer = geometryTensor.getFloatBuffer();
        for (int index : indices) {
            int x = index % height;
            int y = index / height;
            Point2i anchorPointScaled = SCALE_BY_4.scale(x, y);
            out.add(DecodeEAST.extractLabelledBoundingBox(scores, geometryBuffer, index, anchorPointScaled, converter, executionTimeRecorder));
        }
        return out;
    }

    private static LabelledWithConfidence<MultiScaleObject> extractLabelledBoundingBox(FloatBuffer scores, FloatBuffer geometry, int index, Point2i offset, DualScale<MarkToObjectConverter> convertersDual, ExecutionTimeRecorder executionTimeRecorder) {
        MultiScaleObject objectAtScale = MultiScaleObject.extractFrom(convertersDual, converter -> DecodeEAST.createObjectFromGeometry(index, offset, geometry, converter, executionTimeRecorder));
        return new LabelledWithConfidence((Object)objectAtScale, (double)scores.get(index), CLASS_LABEL);
    }

    private static ObjectMask createObjectFromGeometry(int index, Point2i offset, FloatBuffer geometry, MarkToObjectConverter converter, ExecutionTimeRecorder executionTimeRecorder) {
        int indexStart = index * 5;
        RotatableBoundingBox mark = RotatableBoundingBoxFactory.create(vectorIndex -> geometry.get(indexStart + vectorIndex), (Point2i)offset);
        return (ObjectMask)executionTimeRecorder.recordExecutionTime("Convert mark", () -> DecodeEAST.lambda$createObjectFromGeometry$2(converter, (Mark)mark));
    }

    private static DualScale<MarkToObjectConverter> dualScaleConverters(ImageInferenceContext context) {
        return context.scaleFactorUpscale().combine(context.getDimensions(), MarkToObjectConverter::new);
    }

    public double getMinConfidence() {
        return this.minConfidence;
    }

    public void setMinConfidence(double minConfidence) {
        this.minConfidence = minConfidence;
    }

    private static /* synthetic */ ObjectMask lambda$createObjectFromGeometry$2(MarkToObjectConverter converter, Mark mark) throws RuntimeException {
        return converter.convert(mark);
    }
}

