/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.basicmodelzoo.cv.classification;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Activation;
import ai.djl.nn.Block;
import ai.djl.nn.Blocks;
import ai.djl.nn.ParallelBlock;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.convolutional.Conv2D;
import ai.djl.nn.core.Linear;
import ai.djl.nn.norm.BatchNorm;
import ai.djl.nn.pooling.Pool;
import java.util.Arrays;

/**
 * Builds ResNet V1 image-classification networks.
 *
 * <p>Obtain a {@link Builder} from {@link #builder()}, set the depth ({@code numLayers}), the number
 * of outputs ({@code outSize}) and the input image shape, then call {@link Builder#build()}.
 *
 * <p>Supported depths:
 *
 * <ul>
 *   <li>small images (height &lt;= 28): {@code 6n + 2} layers below 164 (basic units), or {@code
 *       9n + 2} layers from 164 upwards (bottleneck units);
 *   <li>larger images: 18 and 34 (basic units), 50, 101, 152, 200 and 269 (bottleneck units).
 * </ul>
 */
public final class ResNetV1 {

    private ResNetV1() {}

    /**
     * Creates one residual unit. A main branch and a shortcut branch run in parallel; their outputs
     * are added element-wise and the sum is passed through ReLU.
     *
     * <p>The main branch is one of two designs:
     *
     * <ul>
     *   <li>a 1x1 / 3x3 / 1x1 bottleneck whose inner convolutions use {@code numFilters / 4}
     *       channels;
     *   <li>two 3x3 convolutions.
     * </ul>
     *
     * <p>Every convolution is followed by batch normalization. Every convolution except the last is
     * also followed by ReLU.
     *
     * @param numFilters number of output channels of the unit
     * @param stride stride of the first main-branch convolution and of the shortcut projection
     * @param dimMatch {@code true} for an identity shortcut; {@code false} for a projection shortcut
     *     (1x1 convolution followed by batch normalization)
     * @param bottleneck {@code true} to use the bottleneck main branch, {@code false} for two 3x3
     *     convolutions
     * @param batchNormMomentum momentum of every batch-normalization layer in the unit
     * @return the residual unit block
     */
    public static Block residualUnit(
            int numFilters,
            Shape stride,
            boolean dimMatch,
            boolean bottleneck,
            float batchNormMomentum) {
        SequentialBlock resUnit = new SequentialBlock();
        if (bottleneck) {
            int innerFilters = numFilters / 4;
            resUnit.add(conv2d(innerFilters, 1, stride, 0, true))
                    .add(batchNorm(1e-5f, batchNormMomentum))
                    .add(Activation::relu)
                    .add(conv2d(innerFilters, 3, square(1), 1, false))
                    .add(batchNorm(2e-5f, batchNormMomentum))
                    .add(Activation::relu)
                    .add(conv2d(numFilters, 1, square(1), 0, true))
                    .add(batchNorm(1e-5f, batchNormMomentum));
        } else {
            resUnit.add(conv2d(numFilters, 3, stride, 1, false))
                    .add(batchNorm(1e-5f, batchNormMomentum))
                    .add(Activation::relu)
                    .add(conv2d(numFilters, 3, square(1), 1, false))
                    .add(batchNorm(1e-5f, batchNormMomentum));
        }

        SequentialBlock shortcut = new SequentialBlock();
        if (dimMatch) {
            shortcut.add(Blocks.identityBlock());
        } else {
            // Projection so the shortcut output matches the main branch's channels and stride.
            shortcut.add(conv2d(numFilters, 1, stride, 0, false))
                    .add(batchNorm(1e-5f, batchNormMomentum));
        }

        return new ParallelBlock(
                list -> {
                    NDArray unit = list.get(0).singletonOrThrow();
                    NDArray parallel = list.get(1).singletonOrThrow();
                    NDArray activated = unit.add(parallel).getNDArrayInternal().relu();
                    return new NDList(activated);
                },
                Arrays.asList(resUnit, shortcut));
    }

    /**
     * Assembles the complete network described by a configured builder.
     *
     * <p>The network is laid out as follows:
     *
     * <ul>
     *   <li>a stem;
     *   <li>one stage of residual units per entry of {@code builder.units};
     *   <li>global average pooling, flatten, and a fully connected layer with {@code
     *       builder.outSize} outputs.
     * </ul>
     *
     * <p>The first unit of every stage uses a projection shortcut. Stages after the first start with
     * stride 2.
     *
     * <p>Expects {@code builder.units}, {@code builder.filters} and {@code builder.bottleneck} to
     * have been populated, as {@link Builder#build()} does before calling this method.
     *
     * @param builder the configured builder
     * @return the ResNet block
     */
    public static Block resnet(Builder builder) {
        int numStages = builder.units.length;
        // NOTE(review): assumes imageShape is (channels, height, width) -- TODO confirm.
        long height = builder.imageShape.get(1);
        SequentialBlock resNet = new SequentialBlock();
        if (height <= 32L) {
            // Small inputs: single 3x3 stride-1 convolution, no downsampling in the stem.
            resNet.add(conv2d(builder.filters[0], 3, square(1), 1, false));
        } else {
            // Larger inputs: 7x7 stride-2 convolution, batch norm, ReLU, then 3x3 stride-2 max pool.
            resNet.add(conv2d(builder.filters[0], 7, square(2), 3, false))
                    .add(batchNorm(2e-5f, builder.batchNormMomentum))
                    .add(Activation.reluBlock())
                    .add(Pool.maxPool2DBlock(square(3), square(2), square(1)));
        }

        Shape resStride = square(1);
        for (int i = 0; i < numStages; ++i) {
            int stageFilters = builder.filters[i + 1];
            resNet.add(
                    residualUnit(
                            stageFilters,
                            resStride,
                            false,
                            builder.bottleneck,
                            builder.batchNormMomentum));
            for (int j = 0; j < builder.units[i] - 1; ++j) {
                resNet.add(
                        residualUnit(
                                stageFilters,
                                square(1),
                                true,
                                builder.bottleneck,
                                builder.batchNormMomentum));
            }
            if (i == 0) {
                // Only the first stage keeps full resolution; every later stage downsamples.
                resStride = square(2);
            }
        }

        return resNet.add(Pool.globalAvgPool2DBlock())
                .add(Blocks.batchFlattenBlock())
                .add(Linear.builder().setOutChannels(builder.outSize).build())
                .add(Blocks.batchFlattenBlock());
    }

    /**
     * Creates a new builder for a ResNet V1 network.
     *
     * @return a new builder
     */
    public static Builder builder() {
        return new Builder();
    }

    /** Creates a 2-D convolution with a square kernel and square padding. */
    private static Block conv2d(int numFilters, long kernel, Shape stride, long pad, boolean bias) {
        return Conv2D.builder()
                .setKernel(square(kernel))
                .setNumFilters(numFilters)
                .optStride(stride)
                .optPad(square(pad))
                .optBias(bias)
                .build();
    }

    /** Creates a batch-normalization layer with the given epsilon and momentum. */
    private static Block batchNorm(float epsilon, float momentum) {
        return BatchNorm.builder().optEpsilon(epsilon).optMomentum(momentum).build();
    }

    /** Returns the 2-D shape {@code (size, size)}. */
    private static Shape square(long size) {
        return new Shape(size, size);
    }

    /** Configures and builds a {@link ResNetV1} network. */
    public static final class Builder {
        // Requested network depth; validated in build().
        int numLayers;
        // Number of residual stages; derived in build().
        int numStages;
        // Number of outputs of the final fully connected layer.
        long outSize;
        // Momentum used by every batch-normalization layer.
        float batchNormMomentum = 0.9f;
        // Input image shape; index 1 is read as the height.
        Shape imageShape;
        // Whether residual units use the bottleneck design; derived in build().
        boolean bottleneck;
        // Residual units per stage; derived in build().
        int[] units;
        // filters[0] is the stem width, filters[i + 1] the output width of stage i; derived in build().
        int[] filters;

        Builder() {}

        /**
         * Sets the network depth.
         *
         * @param numLayers the number of layers
         * @return this builder
         */
        public Builder setNumLayers(int numLayers) {
            this.numLayers = numLayers;
            return this;
        }

        /**
         * Sets the number of outputs of the final fully connected layer.
         *
         * @param outSize the output size
         * @return this builder
         */
        public Builder setOutSize(long outSize) {
            this.outSize = outSize;
            return this;
        }

        /**
         * Sets the momentum of every batch-normalization layer (default {@code 0.9}).
         *
         * @param batchNormMomentum the momentum
         * @return this builder
         */
        public Builder optBatchNormMomentum(float batchNormMomentum) {
            this.batchNormMomentum = batchNormMomentum;
            return this;
        }

        /**
         * Misspelled alias of {@link #optBatchNormMomentum(float)}, retained for backward
         * compatibility.
         *
         * @param batchNormMomemtum the momentum
         * @return this builder
         */
        public Builder optBatchNormMomemtum(float batchNormMomemtum) {
            return optBatchNormMomentum(batchNormMomemtum);
        }

        /**
         * Sets the input image shape; its element at index 1 is used as the image height.
         *
         * @param imageShape the image shape
         * @return this builder
         */
        public Builder setImageShape(Shape imageShape) {
            this.imageShape = imageShape;
            return this;
        }

        /**
         * Derives the stage layout from the depth and image height, then builds the network.
         *
         * @return the ResNet block
         * @throws IllegalArgumentException if no image shape was set or the depth is unsupported
         */
        public Block build() {
            if (this.imageShape == null) {
                throw new IllegalArgumentException("Must set imageShape");
            }
            long height = this.imageShape.get(1);
            // NOTE(review): this threshold (28) differs from the stem threshold in resnet() (32), so
            // e.g. 32-pixel-high inputs get the 4-stage layout with the 3x3 stem -- confirm intended.
            if (height <= 28L) {
                configureSmallImageLayout();
            } else {
                configureLargeImageLayout();
            }
            return ResNetV1.resnet(this);
        }

        /**
         * Sets up the 3-stage layout used for small images.
         *
         * <p>The accepted depths are:
         *
         * <ul>
         *   <li>{@code 9n + 2} layers from 164 upwards, using bottleneck units;
         *   <li>{@code 6n + 2} layers from 8 up to 163, using basic units.
         * </ul>
         */
        private void configureSmallImageLayout() {
            this.numStages = 3;
            int perUnit;
            if ((this.numLayers - 2) % 9 == 0 && this.numLayers >= 164) {
                perUnit = (this.numLayers - 2) / 9;
                this.filters = new int[] {16, 64, 128, 256};
                this.bottleneck = true;
            } else if ((this.numLayers - 2) % 6 == 0
                    && this.numLayers >= 8
                    && this.numLayers < 164) {
                // numLayers >= 8 guarantees at least one unit per stage; smaller or negative depths
                // would otherwise silently build a network that does not match the requested depth.
                perUnit = (this.numLayers - 2) / 6;
                this.filters = new int[] {16, 16, 32, 64};
                this.bottleneck = false;
            } else {
                throw unsupportedDepth(this.numLayers);
            }
            this.units = new int[this.numStages];
            Arrays.fill(this.units, perUnit);
        }

        /** Sets up the 4-stage layout used for larger images. */
        private void configureLargeImageLayout() {
            this.numStages = 4;
            if (this.numLayers >= 50) {
                this.filters = new int[] {64, 256, 512, 1024, 2048};
                this.bottleneck = true;
            } else {
                // ResNet-18/34 use basic (two 3x3 convolution) units, not bottleneck units.
                this.filters = new int[] {64, 64, 128, 256, 512};
                this.bottleneck = false;
            }
            this.units = unitsPerStage(this.numLayers);
        }

        /** Returns the residual units per stage for the supported large-image depths. */
        private static int[] unitsPerStage(int numLayers) {
            switch (numLayers) {
                case 18:
                    return new int[] {2, 2, 2, 2};
                case 34:
                case 50:
                    return new int[] {3, 4, 6, 3};
                case 101:
                    return new int[] {3, 4, 23, 3};
                case 152:
                    return new int[] {3, 8, 36, 3};
                case 200:
                    return new int[] {3, 24, 36, 3};
                case 269:
                    return new int[] {3, 30, 48, 8};
                default:
                    throw unsupportedDepth(numLayers);
            }
        }

        /** Creates the exception reported for a depth with no known layout. */
        private static IllegalArgumentException unsupportedDepth(int numLayers) {
            return new IllegalArgumentException(
                    "no experiments done on num_layers " + numLayers + ", you can do it yourself");
        }
    }
}

