package org.deeplearning4j.optimize.solvers;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.exception.InvalidStepException;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.optimize.GradientAdjustment;
import org.deeplearning4j.optimize.api.ConvexOptimizer;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.optimize.api.StepFunction;
import org.deeplearning4j.optimize.api.TerminationCondition;
import org.deeplearning4j.optimize.terminations.EpsTermination;
import org.deeplearning4j.optimize.terminations.ZeroDirection;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.learning.AdaGrad;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/optimize/solvers/BaseOptimizer.class */
/**
 * Common skeleton for line-search based optimizers.
 *
 * <p>{@link #optimize()} does the following, in order:
 * <ol>
 *   <li>Validates the model input.</li>
 *   <li>Computes the initial gradient and score, and checks the termination conditions.</li>
 *   <li>Optionally performs a subclass-specific first step.</li>
 *   <li>Runs up to {@code conf.getNumIterations()} iterations of {@link BackTrackLineSearch}.</li>
 * </ol>
 * Subclasses customise the search through the hooks {@link #preFirstStepProcess(INDArray)},
 * {@link #postFirstStep(INDArray)}, {@link #preProcessLine(INDArray)} and {@link #postStep()}.
 *
 * <p>The most recent gradient, score and parameters are kept in {@link #searchState} under
 * {@link #GRADIENT_KEY}, {@link #SCORE_KEY} and {@link #PARAMS_KEY}.
 */
public abstract class BaseOptimizer implements ConvexOptimizer {
    /** {@link #searchState} key holding the current gradient ({@link INDArray}). */
    public static final String GRADIENT_KEY = "g";
    /** {@link #searchState} key holding the current score ({@link Double}). */
    public static final String SCORE_KEY = "score";
    /** {@link #searchState} key holding the current model parameters ({@link INDArray}). */
    public static final String PARAMS_KEY = "params";

    protected static final Logger log = LoggerFactory.getLogger(BaseOptimizer.class);

    protected NeuralNetConfiguration conf;
    /** Not assigned by this class; presumably set by subclasses — passed through to GradientAdjustment. */
    protected AdaGrad adaGrad;
    /**
     * Iteration number handed to GradientAdjustment.
     * NOTE(review): this class never increments it — confirm subclasses advance it if
     * iteration-dependent adjustments are expected to take effect.
     */
    protected int iteration;
    protected StepFunction stepFunction;
    protected Collection<IterationListener> iterationListeners = new ArrayList<>();
    protected Collection<TerminationCondition> terminationConditions = new ArrayList<>();
    protected Model model;
    protected BackTrackLineSearch lineMaximizer;
    /** Step size carried between line searches; left unchanged when a line search fails. */
    protected double step;
    private int batchSize = 10;
    protected double score;
    protected double oldScore;
    /** Maximum step handed to the line search. */
    protected double stpMax = Double.MAX_VALUE;
    protected Map<String, AdaGrad> adaGradForVariable = new ConcurrentHashMap<>();
    /** Current search state; a ConcurrentHashMap, so null values cannot be stored. */
    protected Map<String, Object> searchState = new ConcurrentHashMap<>();

    /**
     * Creates an optimizer using the default termination conditions
     * ({@link ZeroDirection} and {@link EpsTermination}).
     *
     * @param configuration network configuration (iteration counts, variables, ...)
     * @param stepFunction step function used by the line search
     * @param iterationListeners listeners notified after each iteration; {@code null} means none
     * @param model the model being optimized
     */
    public BaseOptimizer(NeuralNetConfiguration configuration, StepFunction stepFunction,
            Collection<IterationListener> iterationListeners, Model model) {
        // Wrap in a mutable list so the default behaves like the field's own default
        // (Arrays.asList alone is a fixed-size view that rejects add/remove).
        this(configuration, stepFunction, iterationListeners,
                new ArrayList<TerminationCondition>(
                        Arrays.<TerminationCondition>asList(new ZeroDirection(), new EpsTermination())),
                model);
    }

    /**
     * Creates an optimizer with explicit termination conditions.
     *
     * @param configuration network configuration (iteration counts, variables, ...)
     * @param stepFunction step function used by the line search
     * @param iterationListeners listeners notified after each iteration; {@code null} means none
     * @param terminationConditions conditions checked during {@link #optimize()};
     *        {@code null} means none (previously this caused a NullPointerException in optimize)
     * @param model the model being optimized
     */
    public BaseOptimizer(NeuralNetConfiguration configuration, StepFunction stepFunction,
            Collection<IterationListener> iterationListeners,
            Collection<TerminationCondition> terminationConditions, Model model) {
        this.conf = configuration;
        this.stepFunction = stepFunction;
        if (iterationListeners != null) {
            this.iterationListeners = iterationListeners;
        }
        if (terminationConditions != null) {
            this.terminationConditions = terminationConditions;
        }
        this.model = model;
        // Note: 'this' escapes to the line search before construction completes.
        this.lineMaximizer = new BackTrackLineSearch(model, stepFunction, this);
        this.lineMaximizer.setStpmax(this.stpMax);
        this.lineMaximizer.setMaxIterations(configuration.getNumLineSearchIterations());
    }

    /**
     * Adjusts a raw gradient array via GradientAdjustment using {@link #adaGrad}.
     * NOTE(review): passes a literal 0 as the iteration, unlike the {@link Gradient} overload
     * which passes {@link #iteration} — confirm this is intended.
     */
    @Override
    public void updateGradientAccordingToParams(INDArray gradient, INDArray params, int batchSize) {
        GradientAdjustment.updateGradientAccordingToParams(this.conf, 0, this.adaGrad, gradient, params, batchSize);
    }

    /** Asks the model to recompute its score and returns it. */
    @Override
    public double score() {
        this.model.setScore();
        return this.model.score();
    }

    /**
     * Recomputes the model score, fetches the model's gradient/score pair and passes the
     * gradient through {@link #updateGradientAccordingToParams(Gradient, Model, int)}.
     *
     * @return the pair returned by the model (its gradient presumably adjusted in place)
     */
    @Override
    public Pair<Gradient, Double> gradientAndScore() {
        this.model.setScore();
        Pair<Gradient, Double> result = this.model.gradientAndScore();
        updateGradientAccordingToParams(result.getFirst(), this.model, this.model.batchSize());
        return result;
    }

    /**
     * Runs the optimization loop.
     *
     * @return {@code false} if the subclass-requested first step produced a zero step size;
     *         {@code true} otherwise (a termination condition fired or the iteration budget ran out)
     */
    @Override
    public boolean optimize() {
        this.model.validateInput();
        Pair<Gradient, Double> initial = gradientAndScore();
        setupSearchState(initial);
        this.score = initial.getSecond().doubleValue();
        INDArray gradient = (INDArray) this.searchState.get(GRADIENT_KEY);

        // Pre-search check with both scores fixed at 0.0; presumably conditions treat (0, 0)
        // as "no score history yet" — confirm against the TerminationCondition implementations.
        for (TerminationCondition condition : this.terminationConditions) {
            if (condition.terminate(0.0d, 0.0d, new Object[] {gradient})) {
                log.info("Hit termination condition {}", condition.getClass().getName());
                return true;
            }
        }

        // Some algorithms take a special first step before the main loop.
        if (preFirstStepProcess(gradient)) {
            runLineSearch(gradient);
            gradient = (INDArray) this.searchState.get(GRADIENT_KEY);
            postFirstStep(gradient);
            if (this.step == 0.0d) {
                log.warn("Unable to step in direction");
                return false;
            }
        }

        // NOTE(review): 'gradient' is not re-read from searchState after setupSearchState below,
        // so the same array is passed to preProcessLine, the line search and the termination
        // checks on every iteration — confirm subclasses rely on updating it in place.
        for (int i = 0; i < this.conf.getNumIterations(); i++) {
            preProcessLine(gradient);
            runLineSearch(gradient);

            for (IterationListener listener : this.iterationListeners) {
                listener.iterationDone(this.model, i);
            }

            // Keep the previous score so score-delta conditions can compare against it.
            this.oldScore = this.score;
            setupSearchState(gradientAndScore());
            if (anyTerminationConditionMet(gradient)) {
                return true;
            }

            postStep();
            // Re-check: postStep may have changed the search state.
            if (anyTerminationConditionMet(gradient)) {
                return true;
            }
        }
        return true;
    }

    /**
     * Performs one backtracking line search along {@code direction} and stores the resulting
     * step size. On {@link InvalidStepException} the previous {@link #step} is kept.
     */
    private void runLineSearch(INDArray direction) {
        try {
            this.step = this.lineMaximizer.optimize(this.step, (INDArray) this.searchState.get(PARAMS_KEY), direction);
        } catch (InvalidStepException e) {
            // Deliberate best effort: skip this step and let the next iteration try again.
            log.warn("Invalid step...continuing another iteration");
        }
    }

    /** Returns true if any termination condition fires for the current and previous score. */
    private boolean anyTerminationConditionMet(INDArray gradient) {
        for (TerminationCondition condition : this.terminationConditions) {
            if (condition.terminate(this.score, this.oldScore, new Object[] {gradient})) {
                log.debug("Hit termination condition {}", condition.getClass().getName());
                return true;
            }
        }
        return false;
    }

    /**
     * Hook invoked after the optional first step with the gradient then in {@link #searchState}.
     * Does nothing by default. (Kept public for compatibility; the decompiler reported it was
     * originally protected.)
     */
    public void postFirstStep(INDArray gradient) {
    }

    /**
     * Hook deciding whether a first line-search step is taken before the main loop.
     *
     * @return {@code false} by default (no first step)
     */
    protected boolean preFirstStepProcess(INDArray gradient) {
        return false;
    }

    /** Returns the configured batch size (default 10). */
    @Override
    public int batchSize() {
        return this.batchSize;
    }

    /** Sets the batch size returned by {@link #batchSize()}. */
    @Override
    public void setBatchSize(int batchSize) {
        this.batchSize = batchSize;
    }

    /** Hook run on the search direction before each line search; does nothing by default. */
    @Override
    public void preProcessLine(INDArray line) {
    }

    /** Hook run after each iteration's first termination check; does nothing by default. */
    @Override
    public void postStep() {
    }

    /** Returns {@link #adaGrad}, which this class never assigns (may be null). */
    @Override
    public AdaGrad getAdaGrad() {
        return this.adaGrad;
    }

    /** Returns the live per-variable AdaGrad map (not a copy). */
    @Override
    public Map<String, AdaGrad> adaGradForVariables() {
        return this.adaGradForVariable;
    }

    /** Returns the AdaGrad registered for {@code variable}, or null if none. */
    @Override
    public AdaGrad getAdaGradForVariable(String variable) {
        return this.adaGradForVariable.get(variable);
    }

    /** Adjusts a {@link Gradient} via GradientAdjustment using the per-variable AdaGrad map. */
    @Override
    public void updateGradientAccordingToParams(Gradient gradient, Model model, int batchSize) {
        GradientAdjustment.updateGradientAccordingToParams(this.conf, this.iteration, gradient, batchSize, this.adaGradForVariable, model);
    }

    /**
     * Stores the gradient (for the configured variables), the score and the model's current
     * parameters in {@link #searchState}, and updates {@link #score}.
     */
    @Override
    public void setupSearchState(Pair<Gradient, Double> pair) {
        INDArray gradient = pair.getFirst().gradient(this.conf.variables());
        INDArray params = this.model.params();
        this.searchState.put(GRADIENT_KEY, gradient);
        this.searchState.put(SCORE_KEY, pair.getSecond());
        this.searchState.put(PARAMS_KEY, params);
        this.score = pair.getSecond().doubleValue();
    }
}
