package cn.smartjavaai.ocr.model;

import ai.djl.MalformedModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.ndarray.NDList;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar;
import cn.smartjavaai.common.entity.DetectionResponse;
import cn.smartjavaai.common.pool.PredictorFactory;
import cn.smartjavaai.common.utils.FileUtils;
import cn.smartjavaai.ocr.AbstractOcrModel;
import cn.smartjavaai.ocr.OcrModelConfig;
import cn.smartjavaai.ocr.exception.OcrException;
import cn.smartjavaai.ocr.translator.PaddleOCRV4DetectionTranslator;
import cn.smartjavaai.ocr.utils.ImageUtils;
import cn.smartjavaai.ocr.utils.OcrUtils;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.file.Files;
import java.nio.file.OpenOption;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.pool2.ObjectPool;
import org.apache.commons.pool2.impl.GenericObjectPool;
import org.opencv.core.Mat;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:cn/smartjavaai/ocr/model/PaddleOCRV4DetectModel.class */
/**
 * Text-detection model backed by a PaddleOCR v4 detection network.
 *
 * <p>The network is loaded through DJL's {@link ModelZoo} using the {@code OnnxRuntime} engine and
 * a {@link PaddleOCRV4DetectionTranslator}. Inference does not share one {@link Predictor}: every
 * detection call borrows a predictor from a commons-pool2 {@link GenericObjectPool} and hands it
 * back when done.
 */
public class PaddleOCRV4DetectModel extends AbstractOcrModel {
    private static final Logger log = LoggerFactory.getLogger(PaddleOCRV4DetectModel.class);

    /** The loaded detection model; {@code null} until {@link #loadModel(OcrModelConfig)} succeeds. */
    private ZooModel<Image, NDList> detectionModel;

    /** Pool of predictors created from {@link #detectionModel}. */
    private ObjectPool<Predictor<Image, NDList>> predictorPool;

    /**
     * Loads the detection model from the configured model path and creates the predictor pool.
     *
     * <p>NOTE(review): calling this a second time replaces the fields without closing the previously
     * loaded model or pool.
     *
     * @param ocrModelConfig model configuration; must be non-null with a non-blank model path
     * @throws OcrException if the config/model path is missing, or if the model cannot be loaded
     */
    @Override
    public void loadModel(OcrModelConfig ocrModelConfig) {
        if (ocrModelConfig == null || StringUtils.isBlank(ocrModelConfig.getModelPath())) {
            throw new OcrException("modelPath is null");
        }
        try {
            Criteria<Image, NDList> criteria = Criteria.builder()
                    .optEngine("OnnxRuntime")
                    .setTypes(Image.class, NDList.class)
                    .optModelPath(Paths.get(ocrModelConfig.getModelPath()))
                    .optTranslator(new PaddleOCRV4DetectionTranslator(new ConcurrentHashMap<>()))
                    .optProgress(new ProgressBar())
                    .build();
            this.detectionModel = ModelZoo.loadModel(criteria);
            // PredictorFactory is constructed raw because its generic signature is not visible here.
            this.predictorPool = new GenericObjectPool(new PredictorFactory(this.detectionModel));
            log.info("当前设备: {}", this.detectionModel.getNDManager().getDevice());
        } catch (IOException | ModelNotFoundException | MalformedModelException e) {
            throw new OcrException("模型加载失败", e);
        }
    }

    /**
     * Runs text detection on the image file at the given path.
     *
     * @param str path of the image file
     * @return the detection result produced by {@link OcrUtils#convertToDetectionResponse}
     * @throws OcrException if the file does not exist, cannot be read as an image, or detection fails
     */
    @Override
    public DetectionResponse detect(String str) {
        if (!FileUtils.isFileExists(str)) {
            throw new OcrException("图像文件不存在");
        }
        try {
            return detect(ImageFactory.getInstance().fromFile(Paths.get(str)));
        } catch (IOException e) {
            throw new OcrException("无效的图片", e);
        }
    }

    /**
     * Runs text detection on an already-decoded image using a pooled predictor.
     *
     * <p>The predictor is always handed back to the pool (or closed if that fails), whether
     * inference succeeds or throws.
     *
     * @param image image to analyse
     * @return the detection result for {@code image}
     * @throws OcrException wrapping any exception raised while borrowing the predictor, predicting,
     *     or converting the output
     */
    private DetectionResponse detect(Image image) {
        Predictor<Image, NDList> predictor = null;
        try {
            predictor = this.predictorPool.borrowObject();
            NDList output = predictor.predict(image);
            return OcrUtils.convertToDetectionResponse(output, image);
        } catch (Exception e) {
            throw new OcrException("OCR检测错误", e);
        } finally {
            releasePredictor(predictor);
        }
    }

    /**
     * Returns a borrowed predictor to the pool; if the pool rejects it, closes it instead so its
     * resources are not leaked. Failures here are logged, never thrown, so they cannot mask the
     * outcome of the detection call.
     *
     * @param predictor the borrowed predictor, or {@code null} if borrowing failed
     */
    private void releasePredictor(Predictor<Image, NDList> predictor) {
        if (predictor == null) {
            return;
        }
        try {
            this.predictorPool.returnObject(predictor);
        } catch (Exception e) {
            log.warn("归还Predictor失败", e);
            try {
                predictor.close();
            } catch (Exception closeError) {
                log.error("关闭Predictor失败", closeError);
            }
        }
    }

    /**
     * Detects text in an image file, draws the detected rectangles onto it, and saves it as PNG.
     *
     * <p>NOTE(review): the cast of {@code getWrappedImage()} to {@link Mat} assumes the active
     * {@link ImageFactory} is OpenCV-backed — confirm against the configured DJL image factory.
     *
     * @param str path of the source image file
     * @param str2 destination path of the annotated PNG (created or overwritten)
     * @throws OcrException if the source file does not exist, no text is detected, detection fails,
     *     or reading/writing an image fails
     */
    @Override
    public void detectAndDraw(String str, String str2) {
        if (!FileUtils.isFileExists(str)) {
            throw new OcrException("图像文件不存在");
        }
        try {
            Image image = ImageFactory.getInstance().fromFile(Paths.get(str));
            DetectionResponse response = detect(image);
            if (Objects.isNull(response)
                    || Objects.isNull(response.getRectangleList())
                    || response.getRectangleList().isEmpty()) {
                throw new OcrException("未识别到文字");
            }
            ImageUtils.drawRect((Mat) image.getWrappedImage(), response);
            Path path = Paths.get(str2);
            log.info("Saving to {}", path.toAbsolutePath().toString());
            // try-with-resources: the stream was previously never closed, leaking the file handle
            // and risking an unflushed PNG.
            try (OutputStream out = Files.newOutputStream(path)) {
                image.save(out, "png");
            }
        } catch (IOException e) {
            throw new OcrException(e);
        }
    }
}
