/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.basicmodelzoo.cv.object_detection.ssd;

import ai.djl.Application;
import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.basicmodelzoo.BasicModelZoo;
import ai.djl.basicmodelzoo.cv.object_detection.ssd.SingleShotDetection;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.modality.cv.translator.SingleShotDetectionTranslator;
import ai.djl.modality.cv.translator.wrapper.FileTranslatorFactory;
import ai.djl.modality.cv.translator.wrapper.InputStreamTranslatorFactory;
import ai.djl.modality.cv.translator.wrapper.UrlTranslatorFactory;
import ai.djl.nn.Block;
import ai.djl.nn.SequentialBlock;
import ai.djl.repository.Artifact;
import ai.djl.repository.MRL;
import ai.djl.repository.Repository;
import ai.djl.repository.zoo.BaseModelLoader;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.Transform;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorFactory;
import ai.djl.util.Pair;
import ai.djl.util.Progress;
import java.io.IOException;
import java.io.InputStream;
import java.net.URL;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;

/**
 * Model loader for the basic model zoo's Single Shot Detection (SSD) object-detection model.
 *
 * <p>The loader is registered for application {@code CV.OBJECT_DETECTION}, group {@code
 * ai.djl.zoo}, artifact {@code ssd}, version {@code 0.0.1}. It registers translator factories
 * for {@link Image}, {@link Path}, {@link URL} and {@link InputStream} inputs, all producing
 * {@link DetectedObjects}.
 */
public class SingleShotDetectionModelLoader extends BaseModelLoader<Image, DetectedObjects> {

    private static final Application APPLICATION = Application.CV.OBJECT_DETECTION;
    private static final String GROUP_ID = "ai.djl.zoo";
    private static final String ARTIFACT_ID = "ssd";
    private static final String VERSION = "0.0.1";

    /**
     * Creates the loader and registers one translator factory per supported input type.
     *
     * @param repository the repository handed to {@link BaseModelLoader}
     */
    public SingleShotDetectionModelLoader(Repository repository) {
        super(
                repository,
                MRL.model(APPLICATION, GROUP_ID, ARTIFACT_ID),
                VERSION,
                new BasicModelZoo());
        FactoryImpl factory = new FactoryImpl();
        // The Path/URL/InputStream factories wrap the Image factory, so every input type ends
        // up using the same SSD translator configuration.
        factories.put(new Pair<>(Image.class, DetectedObjects.class), factory);
        factories.put(
                new Pair<>(Path.class, DetectedObjects.class),
                new FileTranslatorFactory<>(factory));
        factories.put(
                new Pair<>(URL.class, DetectedObjects.class), new UrlTranslatorFactory<>(factory));
        factories.put(
                new Pair<>(InputStream.class, DetectedObjects.class),
                new InputStreamTranslatorFactory<>(factory));
    }

    /** {@inheritDoc} */
    @Override
    public Application getApplication() {
        return APPLICATION;
    }

    /**
     * Loads the model that matches the given filters, typed as {@code Image -> DetectedObjects}.
     *
     * @param filters the search filters passed to {@code Criteria.Builder.optFilters}
     * @param device the device passed to {@code Criteria.Builder.optDevice}
     * @param progress the progress tracker passed to {@code Criteria.Builder.optProgress}
     * @return the loaded model
     * @throws IOException if loading the model fails with an I/O error
     * @throws ModelNotFoundException if no model matches the criteria
     * @throws MalformedModelException if the model files are malformed
     */
    @Override
    public ZooModel<Image, DetectedObjects> loadModel(
            Map<String, String> filters, Device device, Progress progress)
            throws IOException, ModelNotFoundException, MalformedModelException {
        Criteria<Image, DetectedObjects> criteria =
                Criteria.builder()
                        .setTypes(Image.class, DetectedObjects.class)
                        .optFilters(filters)
                        .optDevice(device)
                        .optProgress(progress)
                        .build();
        return loadModel(criteria);
    }

    /**
     * Creates a new model and attaches an SSD block built from {@code arguments}.
     *
     * <p>{@code artifact} is not used here.
     *
     * @param name the model name
     * @param device the device the model is created on
     * @param artifact the artifact being loaded (unused)
     * @param arguments the artifact arguments; see {@link #customSSDBlock(Map)} for the keys read
     * @param engine the engine name
     * @return the new model with its block set
     */
    @Override
    protected Model createModel(
            String name,
            Device device,
            Artifact artifact,
            Map<String, Object> arguments,
            String engine) {
        Model model = Model.newInstance(name, device, engine);
        model.setBlock(customSSDBlock(arguments));
        return model;
    }

    /**
     * Builds the SSD block described by the artifact arguments.
     *
     * <p>Keys read: {@code outSize} and {@code numFeatures} (numbers), {@code globalPool}
     * (boolean), {@code numFilters} and {@code ratios} (lists of numbers), and {@code sizes}
     * (list of lists of numbers). Numeric values may be any {@link Number}; they are truncated
     * or converted with {@code intValue()}/{@code floatValue()}.
     *
     * @param arguments the artifact arguments
     * @return the configured SSD block
     * @throws NullPointerException if a required argument is missing
     */
    private Block customSSDBlock(Map<String, Object> arguments) {
        int numClasses = intArgument(arguments, "outSize");
        int numFeatures = intArgument(arguments, "numFeatures");
        boolean globalPool = (Boolean) requireArgument(arguments, "globalPool");
        List<Number> numFilters = listArgument(arguments, "numFilters");
        List<Number> ratioValues = listArgument(arguments, "ratios");
        List<List<Number>> sizeValues = listArgument(arguments, "sizes");

        List<Float> ratio = toFloats(ratioValues);
        List<List<Float>> sizes =
                sizeValues.stream()
                        .map(SingleShotDetectionModelLoader::toFloats)
                        .collect(Collectors.toList());

        // Pair every size list with the same aspect-ratio list, so that
        // ratios.size() == sizes.size() instead of a fixed count.
        List<List<Float>> ratios = new ArrayList<>(sizes.size());
        for (int i = 0; i < sizes.size(); i++) {
            ratios.add(ratio);
        }

        // One down-sampling block per entry of numFilters, chained as the base network.
        SequentialBlock baseBlock = new SequentialBlock();
        for (Number numFilter : numFilters) {
            baseBlock.add(SingleShotDetection.getDownSamplingBlock(numFilter.intValue()));
        }

        return SingleShotDetection.builder()
                .setNumClasses(numClasses)
                .setNumFeatures(numFeatures)
                .optGlobalPool(globalPool)
                .setRatios(ratios)
                .setSizes(sizes)
                .setBaseNetwork(baseBlock)
                .build();
    }

    /**
     * Returns the value stored under {@code key}, failing with a descriptive message if absent.
     *
     * @throws NullPointerException if {@code arguments} is null or has no non-null {@code key}
     */
    private static Object requireArgument(Map<String, Object> arguments, String key) {
        Objects.requireNonNull(arguments, "arguments");
        return Objects.requireNonNull(
                arguments.get(key), () -> "Missing required model argument: " + key);
    }

    /** Reads a required numeric argument as an {@code int} (truncating fractional values). */
    private static int intArgument(Map<String, Object> arguments, String key) {
        return ((Number) requireArgument(arguments, key)).intValue();
    }

    /**
     * Reads a required list argument.
     *
     * <p>The element type cannot be checked at runtime because of erasure. A wrong element type
     * surfaces as a {@link ClassCastException} when the elements are used.
     */
    @SuppressWarnings("unchecked")
    private static <T> List<T> listArgument(Map<String, Object> arguments, String key) {
        return (List<T>) requireArgument(arguments, key);
    }

    /** Converts a list of numbers to a new list of their {@code float} values. */
    private static List<Float> toFloats(List<? extends Number> values) {
        return values.stream().map(Number::floatValue).collect(Collectors.toList());
    }

    /** Creates {@link SingleShotDetectionTranslator}s configured from the artifact arguments. */
    private static final class FactoryImpl implements TranslatorFactory<Image, DetectedObjects> {

        /**
         * Creates a translator that applies {@link ToTensor} and uses the {@code threshold}
         * argument (any {@link Number}) as its detection threshold.
         *
         * @throws NullPointerException if the {@code threshold} argument is missing
         */
        @Override
        public Translator<Image, DetectedObjects> newInstance(Map<String, Object> arguments) {
            float threshold = ((Number) requireArgument(arguments, "threshold")).floatValue();
            return SingleShotDetectionTranslator.builder()
                    .addTransform(new ToTensor())
                    .optThreshold(threshold)
                    .build();
        }
    }
}

