/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.basicmodelzoo.cv.classification;

import ai.djl.Device;
import ai.djl.Model;
import ai.djl.basicmodelzoo.BasicModelZoo;
import ai.djl.basicmodelzoo.cv.classification.ResNetV1;
import ai.djl.modality.cv.zoo.ImageClassificationModelLoader;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.repository.Artifact;
import ai.djl.repository.Repository;
import ai.djl.repository.zoo.ModelZoo;
import java.util.List;
import java.util.Map;

public class ResNetModelLoader
extends ImageClassificationModelLoader {
    private static final String GROUP_ID = "ai.djl.zoo";
    private static final String ARTIFACT_ID = "resnet";
    private static final String VERSION = "0.0.1";

    public ResNetModelLoader(Repository repository) {
        super(repository, GROUP_ID, ARTIFACT_ID, VERSION, (ModelZoo)new BasicModelZoo());
    }

    protected Model createModel(String name, Device device, Artifact artifact, Map<String, Object> arguments, String engine) {
        Model model = Model.newInstance((String)name, (Device)device, (String)engine);
        model.setBlock(this.resnetBlock(arguments));
        return model;
    }

    private Block resnetBlock(Map<String, Object> arguments) {
        Shape shape = new Shape(((List)arguments.get("imageShape")).stream().mapToLong(Double::longValue).toArray());
        ResNetV1.Builder blockBuilder = ResNetV1.builder().setNumLayers((int)((Double)arguments.get("numLayers")).doubleValue()).setOutSize((long)((Double)arguments.get("outSize")).doubleValue()).setImageShape(shape);
        if (arguments.containsKey("batchNormMomentum")) {
            float batchNormMomentum = (float)((Double)arguments.get("batchNormMomentum")).doubleValue();
            blockBuilder.optBatchNormMomentum(batchNormMomentum);
        }
        return blockBuilder.build();
    }
}

