package org.deeplearning4j.nn.layers.factory;

import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Objects;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.LayerFactory;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.params.DefaultParamInitializer;
import org.nd4j.linalg.api.ndarray.INDArray;

/* loaded from: input_file:org/deeplearning4j/nn/layers/factory/DefaultLayerFactory.class */
/**
 * {@link LayerFactory} that instantiates a {@link Layer} reflectively.
 *
 * <p>The layer class must expose a public constructor taking a single
 * {@link NeuralNetConfiguration}. After construction, the layer's parameter
 * table is filled in by {@link #initializer()}, and the configuration is set
 * on the layer.
 */
public class DefaultLayerFactory implements LayerFactory {
    /** Concrete layer type this factory builds; also the sole basis of {@link #equals}. */
    protected Class<? extends Layer> layerClazz;

    /**
     * Creates a factory for the given layer type.
     *
     * @param cls the layer class to instantiate; it must have a public
     *     {@code (NeuralNetConfiguration)} constructor
     */
    public DefaultLayerFactory(Class<? extends Layer> cls) {
        this.layerClazz = cls;
    }

    /**
     * Creates a layer, ignoring both integer arguments. This simply delegates
     * to {@link #create(NeuralNetConfiguration)}.
     *
     * @param neuralNetConfiguration configuration for the new layer
     * @param index ignored by this implementation (NOTE(review): presumably the
     *     layer's position in the network; confirm against {@link LayerFactory})
     * @param numLayers ignored by this implementation (NOTE(review): presumably
     *     the total layer count; confirm against {@link LayerFactory})
     * @return the newly created, initialized layer
     */
    @Override
    public Layer create(NeuralNetConfiguration neuralNetConfiguration, int index, int numLayers) {
        return create(neuralNetConfiguration);
    }

    /**
     * Instantiates the layer, attaches freshly initialized parameters, and
     * sets its configuration.
     *
     * @param neuralNetConfiguration configuration for the new layer
     * @return the newly created, initialized layer
     * @throws IllegalStateException if the layer class cannot be instantiated
     */
    @Override
    public Layer create(NeuralNetConfiguration neuralNetConfiguration) {
        Layer layer = getInstance(neuralNetConfiguration);
        layer.setParamTable(getParams(neuralNetConfiguration));
        layer.setConfiguration(neuralNetConfiguration);
        return layer;
    }

    /**
     * Reflectively invokes {@code layerClazz}'s {@code (NeuralNetConfiguration)}
     * constructor.
     *
     * @param neuralNetConfiguration argument passed to the layer constructor
     * @return the new layer instance
     * @throws IllegalStateException wrapping any failure, including a missing
     *     constructor, an access error, or an exception thrown by the constructor
     *     itself. The original exception is kept as the cause.
     */
    protected Layer getInstance(NeuralNetConfiguration neuralNetConfiguration) {
        try {
            return layerClazz.getConstructor(NeuralNetConfiguration.class).newInstance(neuralNetConfiguration);
        } catch (Exception e) {
            // Broad catch is deliberate: every reflective or constructor failure
            // is surfaced uniformly as IllegalStateException, as callers expect.
            throw new IllegalStateException(
                    "Unable to instantiate layer "
                            + (layerClazz == null ? "null" : layerClazz.getName())
                            + " via its (NeuralNetConfiguration) constructor",
                    e);
        }
    }

    /**
     * Builds the parameter table for a new layer using {@link #initializer()}.
     *
     * @param neuralNetConfiguration configuration the initializer reads
     * @return a mutable map populated by the initializer; iteration follows
     *     insertion order because it is backed by a {@link LinkedHashMap}
     */
    protected Map<String, INDArray> getParams(NeuralNetConfiguration neuralNetConfiguration) {
        ParamInitializer paramInitializer = initializer();
        Map<String, INDArray> params = new LinkedHashMap<>();
        paramInitializer.init(params, neuralNetConfiguration);
        return params;
    }

    /**
     * Returns the binary name of the layer class this factory produces.
     *
     * @return {@code layerClazz.getName()}
     */
    @Override
    public String layerClazzName() {
        return layerClazz.getName();
    }

    /**
     * Returns the parameter initializer used by {@link #getParams}. Each call
     * returns a new {@link DefaultParamInitializer}. Subclasses may override
     * this to use a different initializer.
     *
     * @return a new parameter initializer
     */
    @Override
    public ParamInitializer initializer() {
        return new DefaultParamInitializer();
    }

    /**
     * Two factories are equal when both are {@code DefaultLayerFactory}
     * instances (including subclasses) with equal {@code layerClazz} values.
     *
     * <p>NOTE(review): a subclass that overrides {@link #initializer()} still
     * compares equal to a plain {@code DefaultLayerFactory} for the same layer
     * class. Confirm that callers rely on this before changing it.
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (!(obj instanceof DefaultLayerFactory)) {
            return false;
        }
        DefaultLayerFactory other = (DefaultLayerFactory) obj;
        return Objects.equals(layerClazz, other.layerClazz);
    }

    /** Hash code derived solely from {@code layerClazz}; 0 when it is {@code null}. */
    @Override
    public int hashCode() {
        return Objects.hashCode(layerClazz);
    }
}
