/*
 * Decompiled with CFR 0.152.
 */
package org.anchoranalysis.plugin.onnx.bean.object.segment.decode.instance.maskrcnn;

import ai.onnxruntime.OnnxTensor;
import java.nio.FloatBuffer;
import java.nio.LongBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import lombok.Generated;
import org.anchoranalysis.bean.annotation.BeanField;
import org.anchoranalysis.bean.annotation.DefaultInstance;
import org.anchoranalysis.core.exception.OperationFailedException;
import org.anchoranalysis.core.functional.FunctionalList;
import org.anchoranalysis.image.bean.interpolator.Interpolator;
import org.anchoranalysis.image.inference.ImageInferenceContext;
import org.anchoranalysis.image.inference.bean.segment.instance.DecodeInstanceSegmentation;
import org.anchoranalysis.image.inference.segment.LabelledWithConfidence;
import org.anchoranalysis.image.inference.segment.MultiScaleObject;
import org.anchoranalysis.plugin.onnx.bean.object.segment.decode.instance.maskrcnn.ExtractMaskHelper;
import org.anchoranalysis.plugin.onnx.bean.object.segment.decode.instance.maskrcnn.ExtractObjectHelper;

/**
 * Decodes the raw outputs of a Mask R-CNN ONNX model into labelled objects with confidences.
 *
 * <p>Four output tensors are consumed: boxes, labels, scores and masks. Only proposals whose
 * score is greater than or equal to {@link #getMinConfidence()} are passed on for per-object
 * extraction.
 */
public class DecodeMaskRCNN
extends DecodeInstanceSegmentation<OnnxTensor> {
    // Output-node names in the ONNX graph.
    // NOTE(review): these numeric names presumably come from one specific model export and would
    // differ for another export of the same architecture; confirm against the model in use.
    private static final String TENSOR_LABELS = "6570";
    private static final String TENSOR_SCORES = "6572";
    private static final String TENSOR_BOXES = "6568";
    private static final String TENSOR_MASKS = "6887";

    // Position of each tensor within the inference output. expectedOutputs() is built from these
    // same constants, so the declared order and the indices used when decoding cannot drift apart.
    private static final int INDEX_BOXES = 0;
    private static final int INDEX_LABELS = 1;
    private static final int INDEX_SCORES = 2;
    private static final int INDEX_MASKS = 3;
    private static final int NUMBER_OUTPUTS = 4;

    /** Proposals with a score greater than or equal to this threshold are decoded; others are ignored. */
    @BeanField
    private float minConfidence = 0.5f;

    /**
     * Threshold forwarded to the per-object mask extraction. Presumably mask values at or above it
     * are treated as foreground — TODO confirm in {@code ExtractObjectHelper}.
     */
    @BeanField
    private float minMaskValue = 0.5f;

    /** Interpolator bean. NOTE(review): not referenced by the decoding logic in this class. */
    @BeanField
    @DefaultInstance
    private Interpolator interpolator;

    /**
     * Names of the output tensors this decoder requires.
     *
     * @return a new fixed-size list whose order matches the positions read by {@link #decode}.
     */
    public List<String> expectedOutputs() {
        String[] names = new String[NUMBER_OUTPUTS];
        names[INDEX_BOXES] = TENSOR_BOXES;
        names[INDEX_LABELS] = TENSOR_LABELS;
        names[INDEX_SCORES] = TENSOR_SCORES;
        names[INDEX_MASKS] = TENSOR_MASKS;
        return Arrays.asList(names);
    }

    /**
     * Decodes the model outputs into objects whose score meets {@link #getMinConfidence()}.
     *
     * <p>NOTE(review): assumes {@code inferenceOutput} is ordered like {@link #expectedOutputs()};
     * the caller establishing that order is not visible here.
     *
     * @param inferenceOutput the output tensors of the model, one per expected output.
     * @param context context passed through to per-object extraction.
     * @return the extracted objects, or a new empty (mutable) list if no proposal meets the
     *     confidence threshold.
     * @throws OperationFailedException if an expected tensor is missing or has the wrong element
     *     type, or if mask/object extraction fails.
     */
    public List<LabelledWithConfidence<MultiScaleObject>> decode(List<OnnxTensor> inferenceOutput, ImageInferenceContext context) throws OperationFailedException {
        FloatBuffer scores = floatBufferAt(inferenceOutput, INDEX_SCORES, TENSOR_SCORES);
        List<Integer> indices = indicesAboveThreshold(scores);
        if (indices.isEmpty()) {
            return new ArrayList<>();
        }
        return extractObjects(indices, scores, inferenceOutput, context);
    }

    /** Extracts one object for each selected proposal index, skipping those the helper rejects. */
    private List<LabelledWithConfidence<MultiScaleObject>> extractObjects(List<Integer> indices, FloatBuffer scores, List<OnnxTensor> inferenceOutput, ImageInferenceContext context) throws OperationFailedException {
        // Use limit(), not capacity(): absolute get(i) is only valid for i < limit().
        int numberProposals = scores.limit();
        FloatBuffer masks = floatBufferAt(inferenceOutput, INDEX_MASKS, TENSOR_MASKS);
        ExtractMaskHelper.checkMaskBufferSize(masks, numberProposals);
        LongBuffer labels = longBufferAt(inferenceOutput, INDEX_LABELS, TENSOR_LABELS);
        FloatBuffer boxes = floatBufferAt(inferenceOutput, INDEX_BOXES, TENSOR_BOXES);
        return FunctionalList.mapToListOptional(indices, index -> ExtractObjectHelper.extractAt(index, scores, masks, labels, boxes, this.minMaskValue, context));
    }

    /** Indices (ascending) of all proposals whose score is at least {@link #minConfidence}. */
    private List<Integer> indicesAboveThreshold(FloatBuffer scores) {
        List<Integer> indices = new ArrayList<>();
        for (int i = 0; i < scores.limit(); i++) {
            // A NaN score fails this comparison and is therefore excluded.
            if (scores.get(i) >= this.minConfidence) {
                indices.add(i);
            }
        }
        return indices;
    }

    /** Returns the tensor at {@code index}, failing with a descriptive message if it is absent. */
    private static OnnxTensor tensorAt(List<OnnxTensor> inferenceOutput, int index, String name) throws OperationFailedException {
        if (index >= inferenceOutput.size()) {
            throw new OperationFailedException(String.format("Missing output tensor '%s' at position %d: only %d tensor(s) were supplied.", name, index, inferenceOutput.size()));
        }
        return inferenceOutput.get(index);
    }

    /** Reads the tensor at {@code index} as floats; ONNX Runtime yields null for non-float tensors. */
    private static FloatBuffer floatBufferAt(List<OnnxTensor> inferenceOutput, int index, String name) throws OperationFailedException {
        FloatBuffer buffer = tensorAt(inferenceOutput, index, name).getFloatBuffer();
        if (buffer == null) {
            throw new OperationFailedException(String.format("Output tensor '%s' does not contain float data.", name));
        }
        return buffer;
    }

    /** Reads the tensor at {@code index} as longs; ONNX Runtime yields null for non-int64 tensors. */
    private static LongBuffer longBufferAt(List<OnnxTensor> inferenceOutput, int index, String name) throws OperationFailedException {
        LongBuffer buffer = tensorAt(inferenceOutput, index, name).getLongBuffer();
        if (buffer == null) {
            throw new OperationFailedException(String.format("Output tensor '%s' does not contain int64 data.", name));
        }
        return buffer;
    }

    /** @return the minimum score a proposal needs in order to be decoded. */
    @Generated
    public float getMinConfidence() {
        return this.minConfidence;
    }

    /** @param minConfidence the minimum score a proposal needs in order to be decoded. */
    @Generated
    public void setMinConfidence(float minConfidence) {
        this.minConfidence = minConfidence;
    }

    /** @return the threshold forwarded to mask extraction. */
    @Generated
    public float getMinMaskValue() {
        return this.minMaskValue;
    }

    /** @param minMaskValue the threshold forwarded to mask extraction. */
    @Generated
    public void setMinMaskValue(float minMaskValue) {
        this.minMaskValue = minMaskValue;
    }

    /** @return the configured interpolator. */
    @Generated
    public Interpolator getInterpolator() {
        return this.interpolator;
    }

    /** @param interpolator the interpolator to configure. */
    @Generated
    public void setInterpolator(Interpolator interpolator) {
        this.interpolator = interpolator;
    }
}

