/*
 * Decompiled with CFR 0.152.
 */
package org.anchoranalysis.plugin.onnx.bean.object.segment.stack;

import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtLoggingLevel;
import ai.onnxruntime.OrtProvider;
import ai.onnxruntime.OrtSession;
import java.io.IOException;
import java.io.InputStream;
import java.nio.FloatBuffer;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Optional;
import org.anchoranalysis.bean.annotation.BeanField;
import org.anchoranalysis.core.exception.CreateException;
import org.anchoranalysis.core.exception.InitializeException;
import org.anchoranalysis.core.exception.OperationFailedException;
import org.anchoranalysis.core.exception.friendly.AnchorImpossibleSituationException;
import org.anchoranalysis.core.log.Logger;
import org.anchoranalysis.image.core.channel.Channel;
import org.anchoranalysis.image.core.dimensions.Dimensions;
import org.anchoranalysis.image.core.dimensions.IncorrectImageSizeException;
import org.anchoranalysis.image.core.stack.Stack;
import org.anchoranalysis.image.inference.bean.segment.instance.SegmentStackIntoObjectsScaleDecode;
import org.anchoranalysis.inference.InferenceModel;
import org.anchoranalysis.inference.concurrency.ConcurrencyPlan;
import org.anchoranalysis.inference.concurrency.ConcurrentModel;
import org.anchoranalysis.inference.concurrency.ConcurrentModelPool;
import org.anchoranalysis.inference.concurrency.CreateModelFailedException;
import org.anchoranalysis.plugin.onnx.bean.object.segment.stack.BufferFromStack;
import org.anchoranalysis.plugin.onnx.model.OnnxModel;
import org.apache.commons.io.IOUtils;

/**
 * Segments a {@link Stack} into objects using a model in ONNX format, executed via ONNX Runtime.
 *
 * <p>The model bytes are read either from a file or from a resource on the thread's context
 * class-loader. They are cached in this bean after the first successful read, so later model
 * instances created for the pool reuse the same bytes.
 */
public class SegmentObjectsFromONNXModel
extends SegmentStackIntoObjectsScaleDecode<OnnxTensor, OnnxModel> {
    /**
     * Location of the ONNX model.
     *
     * <p>This is a resource name when {@code readFromResources} is true. Otherwise it is a path
     * that is passed to {@code resolve} before being read.
     */
    @BeanField
    private String modelPath;

    /** When true, {@code modelPath} is looked up as a class-loader resource rather than a file. */
    @BeanField
    private boolean readFromResources = false;

    /** Name of the model input, as returned by {@link #inputName()}. */
    @BeanField
    private String inputName;

    /**
     * When true, the input tensor has shape {@code [1, height, width, channels]}; otherwise {@code
     * [channels, height, width]}.
     */
    @BeanField
    private boolean includeBatchDimension = false;

    /** Passed to {@link BufferFromStack#createFrom} when building the input buffer. */
    @BeanField
    private boolean interleaveChannels = false;

    /** Cached model bytes, lazily populated by {@link #readModelIfNecessary()}. */
    private byte[] modelAsBytes;

    /**
     * Creates a pool of ONNX models, each prepared by {@link #readPrepareModel(boolean)}.
     *
     * @param plan how many CPU/GPU model instances to create.
     * @param logger logger passed to the pool.
     * @return the newly created pool.
     * @throws CreateModelFailedException if a model cannot be created.
     */
    public ConcurrentModelPool<OnnxModel> createModelPool(ConcurrencyPlan plan, Logger logger) throws CreateModelFailedException {
        return new ConcurrentModelPool<>(plan, this::readPrepareModel, logger);
    }

    /**
     * Converts a stack into an ONNX input tensor.
     *
     * <p>The first three channels are reordered (channel 2, 1, 0) before being flattened into a
     * float buffer. The tensor's shape is determined by {@code includeBatchDimension}.
     *
     * @param stack the stack to convert, which must have at least three channels.
     * @param subtractMeans optional per-channel means, passed to {@link BufferFromStack}.
     * @return a newly created tensor, which the caller is responsible for closing.
     * @throws OperationFailedException if ONNX Runtime cannot create the tensor.
     */
    protected OnnxTensor deriveInput(Stack stack, Optional<double[]> subtractMeans) throws OperationFailedException {
        Stack reordered = convertToBGR(stack);
        FloatBuffer buffer = BufferFromStack.createFrom(reordered, subtractMeans, interleaveChannels);
        buffer.rewind();
        try {
            return OnnxTensor.createTensor(OrtEnvironment.getEnvironment(), buffer, deriveShape(reordered));
        }
        catch (OrtException e) {
            throw new OperationFailedException(e);
        }
    }

    /**
     * The name of the model input.
     *
     * @return {@code inputName}, wrapped in an {@link Optional}.
     */
    protected Optional<String> inputName() {
        return Optional.of(getInputName());
    }

    /**
     * Adds the CUDA execution provider to {@code options}, if ONNX Runtime reports it as available.
     *
     * @param options the session options to modify.
     * @return true if CUDA was added, false if it is unavailable or could not be added.
     */
    private boolean configureCUDAIfPossible(OrtSession.SessionOptions options) {
        if (!OrtEnvironment.getAvailableProviders().contains(OrtProvider.CUDA)) {
            return false;
        }
        try {
            options.addCUDA();
            return true;
        }
        catch (OrtException ignored) {
            // Deliberate best-effort: failure is reported via the return value so the caller can
            // decline to create a GPU model.
            return false;
        }
    }

    /**
     * Reads the model (if not already cached) and creates an ONNX Runtime session for it.
     *
     * @param useGPU whether the session should use the CUDA execution provider.
     * @return the model, or {@link Optional#empty()} if a GPU was requested but CUDA could not be
     *     configured.
     * @throws CreateModelFailedException if the model bytes cannot be read or the session cannot be
     *     created.
     */
    private Optional<ConcurrentModel<OnnxModel>> readPrepareModel(boolean useGPU) throws CreateModelFailedException {
        OrtSession.SessionOptions options = null;
        try {
            OrtEnvironment env = OrtEnvironment.getEnvironment(OrtLoggingLevel.ORT_LOGGING_LEVEL_FATAL);
            options = new OrtSession.SessionOptions();
            if (useGPU && !configureCUDAIfPossible(options)) {
                options.close();
                return Optional.empty();
            }
            OrtSession session = env.createSession(readModelIfNecessary(), options);
            return Optional.of(new ConcurrentModel<>(new OnnxModel(session), useGPU));
        }
        catch (OrtException | IOException | InitializeException e) {
            // Release the native options, which would otherwise leak when reading the model
            // or creating the session fails.
            if (options != null) {
                options.close();
            }
            throw new CreateModelFailedException(e);
        }
    }

    /**
     * Determines the shape of the input tensor from the stack's dimensions and channel count.
     *
     * <p>NOTE(review): the channel axis is last when a batch dimension is included, but first
     * otherwise. Confirm this matches the layouts the target models expect.
     *
     * @param stack the (already reordered) stack.
     * @return {@code [1, y, x, channels]} or {@code [channels, y, x]}.
     */
    private long[] deriveShape(Stack stack) {
        Dimensions dimensions = stack.getChannel(0).dimensions();
        if (includeBatchDimension) {
            return new long[]{1L, dimensions.y(), dimensions.x(), stack.getNumberChannels()};
        }
        return new long[]{stack.getNumberChannels(), dimensions.y(), dimensions.x()};
    }

    /**
     * Returns the model bytes, reading and caching them on first use.
     *
     * @return the bytes of the ONNX model.
     * @throws IOException if the file or resource cannot be read, or the resource does not exist.
     * @throws InitializeException if {@code modelPath} cannot be resolved to a file path.
     */
    private byte[] readModelIfNecessary() throws IOException, InitializeException {
        if (modelAsBytes == null) {
            if (readFromResources) {
                modelAsBytes = readModelFromResources();
            } else {
                Path path = resolve(modelPath);
                modelAsBytes = Files.readAllBytes(path);
            }
        }
        return modelAsBytes;
    }

    /**
     * Reads all bytes of the resource {@code modelPath} from the thread's context class-loader.
     *
     * @return the bytes of the resource.
     * @throws IOException if the resource does not exist or cannot be read.
     */
    private byte[] readModelFromResources() throws IOException {
        ClassLoader classLoader = Thread.currentThread().getContextClassLoader();
        // try-with-resources ensures the stream is closed; a null resource is permitted here.
        try (InputStream inputStream = classLoader.getResourceAsStream(modelPath)) {
            if (inputStream == null) {
                throw new IOException("Cannot find ONNX model resource: " + modelPath);
            }
            return IOUtils.toByteArray(inputStream);
        }
    }

    /**
     * Creates a stack from the first three channels of {@code stack}, in reverse order (2, 1, 0).
     *
     * <p>Any channels beyond the first three are dropped. This assumes the incoming channels are
     * ordered RGB, producing BGR.
     *
     * @param stack a stack with at least three channels.
     * @return a new stack referencing the reordered channels.
     */
    private static Stack convertToBGR(Stack stack) {
        try {
            return new Stack(false, new Channel[]{stack.getChannel(2), stack.getChannel(1), stack.getChannel(0)});
        }
        catch (CreateException | IncorrectImageSizeException e) {
            // The channels all come from the same stack, so they should always be compatible.
            throw new AnchorImpossibleSituationException();
        }
    }

    /** @return {@code modelPath}. */
    public String getModelPath() {
        return this.modelPath;
    }

    /** @param modelPath the new {@code modelPath}. */
    public void setModelPath(String modelPath) {
        this.modelPath = modelPath;
    }

    /** @return {@code readFromResources}. */
    public boolean isReadFromResources() {
        return this.readFromResources;
    }

    /** @param readFromResources the new {@code readFromResources}. */
    public void setReadFromResources(boolean readFromResources) {
        this.readFromResources = readFromResources;
    }

    /** @return {@code inputName}. */
    public String getInputName() {
        return this.inputName;
    }

    /** @param inputName the new {@code inputName}. */
    public void setInputName(String inputName) {
        this.inputName = inputName;
    }

    /** @return {@code includeBatchDimension}. */
    public boolean isIncludeBatchDimension() {
        return this.includeBatchDimension;
    }

    /** @param includeBatchDimension the new {@code includeBatchDimension}. */
    public void setIncludeBatchDimension(boolean includeBatchDimension) {
        this.includeBatchDimension = includeBatchDimension;
    }

    /** @return {@code interleaveChannels}. */
    public boolean isInterleaveChannels() {
        return this.interleaveChannels;
    }

    /** @param interleaveChannels the new {@code interleaveChannels}. */
    public void setInterleaveChannels(boolean interleaveChannels) {
        this.interleaveChannels = interleaveChannels;
    }
}

