package cn.smartjavaai.ocr.factory;

import cn.smartjavaai.common.config.Config;
import cn.smartjavaai.ocr.config.DirectionModelConfig;
import cn.smartjavaai.ocr.config.OcrDetModelConfig;
import cn.smartjavaai.ocr.config.OcrRecModelConfig;
import cn.smartjavaai.ocr.exception.OcrException;
import cn.smartjavaai.ocr.model.common.detect.OcrCommonDetModel;
import cn.smartjavaai.ocr.model.common.detect.PpOCRV5DetModel;
import cn.smartjavaai.ocr.model.common.direction.OcrDirectionModel;
import cn.smartjavaai.ocr.model.common.direction.PPOCRMobileV2Model;
import cn.smartjavaai.ocr.model.common.recognize.OcrCommonRecModel;
import cn.smartjavaai.ocr.model.common.recognize.PpOCRV5RecModel;
import java.lang.reflect.InvocationTargetException;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:cn/smartjavaai/ocr/factory/OcrModelFactory.class */
public class OcrModelFactory {
    private static volatile OcrModelFactory instance;
    private static final Logger log = LoggerFactory.getLogger(OcrModelFactory.class);
    private static final ConcurrentHashMap<String, OcrCommonDetModel> commonDetModelMap = new ConcurrentHashMap<>();
    private static final ConcurrentHashMap<String, OcrCommonRecModel> commonRecModelMap = new ConcurrentHashMap<>();
    private static final ConcurrentHashMap<String, OcrDirectionModel> directionModelMap = new ConcurrentHashMap<>();
    private static final Map<String, Class<? extends OcrCommonDetModel>> commonDetRegistry = new ConcurrentHashMap();
    private static final Map<String, Class<? extends OcrCommonRecModel>> commonRecRegistry = new ConcurrentHashMap();
    private static final Map<String, Class<? extends OcrDirectionModel>> directionRegistry = new ConcurrentHashMap();

    public static OcrModelFactory getInstance() {
        if (instance == null) {
            synchronized (OcrModelFactory.class) {
                if (instance == null) {
                    instance = new OcrModelFactory();
                }
            }
        }
        return instance;
    }

    private static void registerCommonDetModel(String str, Class<? extends OcrCommonDetModel> cls) {
        commonDetRegistry.put(str.toLowerCase(), cls);
    }

    private static void registerCommonRecModel(String str, Class<? extends OcrCommonRecModel> cls) {
        commonRecRegistry.put(str.toLowerCase(), cls);
    }

    private static void registerDirectionModel(String str, Class<? extends OcrDirectionModel> cls) {
        directionRegistry.put(str.toLowerCase(), cls);
    }

    public OcrCommonDetModel getDetModel(OcrDetModelConfig ocrDetModelConfig) {
        if (Objects.isNull(ocrDetModelConfig) || Objects.isNull(ocrDetModelConfig.getModelEnum())) {
            throw new OcrException("未配置OCR模型");
        }
        return commonDetModelMap.computeIfAbsent(ocrDetModelConfig.getModelEnum().name(), str -> {
            return createCommonDetModel(ocrDetModelConfig);
        });
    }

    public OcrCommonRecModel getRecModel(OcrRecModelConfig ocrRecModelConfig) {
        if (Objects.isNull(ocrRecModelConfig) || Objects.isNull(ocrRecModelConfig.getRecModelEnum())) {
            throw new OcrException("未配置OCR模型");
        }
        return commonRecModelMap.computeIfAbsent(ocrRecModelConfig.getRecModelEnum().name(), str -> {
            return createCommonRecModel(ocrRecModelConfig);
        });
    }

    public OcrDirectionModel getDirectionModel(DirectionModelConfig directionModelConfig) {
        if (Objects.isNull(directionModelConfig) || Objects.isNull(directionModelConfig.getModelEnum())) {
            throw new OcrException("未配置OCR模型");
        }
        return directionModelMap.computeIfAbsent(directionModelConfig.getModelEnum().name(), str -> {
            return createDirectionModel(directionModelConfig);
        });
    }

    private OcrCommonDetModel createCommonDetModel(OcrDetModelConfig ocrDetModelConfig) {
        Class<? extends OcrCommonDetModel> cls = commonDetRegistry.get(ocrDetModelConfig.getModelEnum().name().toLowerCase());
        if (cls == null) {
            throw new OcrException("Unsupported model");
        }
        try {
            OcrCommonDetModel newInstance = cls.newInstance();
            newInstance.loadModel(ocrDetModelConfig);
            return newInstance;
        } catch (IllegalAccessException | InstantiationException e) {
            throw new OcrException(e);
        }
    }

    private OcrCommonRecModel createCommonRecModel(OcrRecModelConfig ocrRecModelConfig) {
        Class<? extends OcrCommonRecModel> cls = commonRecRegistry.get(ocrRecModelConfig.getRecModelEnum().name().toLowerCase());
        if (cls == null) {
            throw new OcrException("Unsupported model");
        }
        try {
            OcrCommonRecModel newInstance = cls.newInstance();
            newInstance.loadModel(ocrRecModelConfig);
            return newInstance;
        } catch (IllegalAccessException | InstantiationException e) {
            throw new OcrException(e);
        }
    }

    private OcrDirectionModel createDirectionModel(DirectionModelConfig directionModelConfig) {
        Class<? extends OcrDirectionModel> cls = directionRegistry.get(directionModelConfig.getModelEnum().name().toLowerCase());
        if (cls == null) {
            throw new OcrException("Unsupported model");
        }
        try {
            OcrDirectionModel newInstance = cls.newInstance();
            newInstance.loadModel(directionModelConfig);
            return newInstance;
        } catch (IllegalAccessException | InstantiationException e) {
            throw new OcrException(e);
        }
    }

    static {
        registerCommonDetModel("PADDLEOCR_V5_DET_MODEL", PpOCRV5DetModel.class);
        registerCommonRecModel("PADDLEOCR_V5_REC_MODEL", PpOCRV5RecModel.class);
        registerDirectionModel("CH_PPOCR_MOBILE_V2_CLS", PPOCRMobileV2Model.class);
        log.info("缓存目录：{}", Config.getCachePath());
    }
}
