/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.training.optimizer.learningrate;

import ai.djl.training.optimizer.learningrate.LearningRateTracker;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A {@link LearningRateTracker} that multiplies the learning rate by a constant factor every
 * {@code step} updates, clamping the result so it never falls below a configurable floor.
 *
 * <p>While {@code numUpdate < warmUpSteps} the parent's warm-up learning rate is returned instead.
 *
 * <p>NOTE(review): instances keep mutable decay state ({@code count} here and the inherited
 * {@code baseLearningRate}), so this appears to assume {@code numUpdate} is non-decreasing across
 * calls and that an instance is not shared between threads — confirm against callers.
 */
public class FactorTracker extends LearningRateTracker {
    private static final Logger logger = LoggerFactory.getLogger(FactorTracker.class);

    /** Number of updates between two consecutive decays (validated by the builder). */
    private final int step;
    /** Multiplier applied to the learning rate at every decay (the builder rejects values above 1). */
    private final float factor;
    /** Lower bound that the decayed learning rate is clamped to. */
    private final float stopFactorLearningRate;
    /** Update count up to which decays have already been applied; always a multiple of {@code step}. */
    private int count;

    /**
     * Creates a tracker from the settings collected by the given builder.
     *
     * @param builder the builder holding the step, factor and learning-rate floor
     */
    public FactorTracker(Builder builder) {
        super(builder);
        this.step = builder.step;
        this.factor = builder.factor;
        this.stopFactorLearningRate = builder.stopFactorLearningRate;
        this.count = 0;
    }

    /**
     * Returns the learning rate to use for the given update.
     *
     * <p>Before {@code warmUpSteps} updates, the parent's warm-up rate is returned. After that, the
     * base learning rate is multiplied by {@code factor} once for every {@code step}-sized interval
     * that {@code numUpdate} has moved past since the previous call. The first decay happens when
     * {@code numUpdate} exceeds {@code step}, the second when it exceeds {@code 2 * step}, and so
     * on. Whenever a decayed rate would fall below {@code stopFactorLearningRate}, it is clamped
     * to that value. The resulting rate is passed to {@code checkLearningRate} (defined in the
     * parent class) before being returned.
     *
     * @param numUpdate the index of the current update
     * @return the learning rate for this update
     */
    @Override
    public float getNewLearningRate(int numUpdate) {
        if (numUpdate < this.warmUpSteps) {
            return this.getWarmUpLearningRate(numUpdate);
        }
        // Compare in long arithmetic: with int arithmetic, count + step wraps negative once count
        // is the largest multiple of step below Integer.MAX_VALUE, and the loop would then spin
        // with count overflowing. The loop condition guarantees count + step < numUpdate, so the
        // int increment below cannot overflow.
        while ((long) numUpdate > (long) this.count + this.step) {
            this.count += this.step;
            float previous = this.baseLearningRate;
            this.baseLearningRate *= this.factor;
            if (this.baseLearningRate < this.stopFactorLearningRate) {
                this.baseLearningRate = this.stopFactorLearningRate;
                // Announce the floor only when it is first reached, not on every later interval.
                if (previous != this.stopFactorLearningRate && logger.isDebugEnabled()) {
                    logger.debug(
                            "Update[{}]: now learning rate arrived at {}, will not change in the future",
                            numUpdate,
                            String.format("%.5e", this.baseLearningRate));
                }
            } else if (logger.isDebugEnabled()) {
                // Guarded so String.format is not evaluated when debug logging is disabled.
                logger.debug(
                        "Update[{}]: Change learning rate to {}",
                        numUpdate,
                        String.format("%.5e", this.baseLearningRate));
            }
        }
        this.checkLearningRate(this.baseLearningRate);
        return this.baseLearningRate;
    }

    /** Builder for {@link FactorTracker}. */
    public static final class Builder extends LearningRateTracker.LrBaseBuilder<Builder> {
        /** Decay interval; 0 means "not set" and is rejected by {@link #build()}. */
        int step;
        /** Decay multiplier; defaults to 1 (no decay). */
        float factor = 1.0f;
        /** Learning-rate floor; defaults to 1e-8. */
        float stopFactorLearningRate = 1.0E-8f;

        Builder() {
        }

        /** {@inheritDoc} */
        @Override
        protected Builder self() {
            return this;
        }

        /**
         * Sets how many updates pass between two consecutive decays.
         *
         * @param step the decay interval, at least 1
         * @return this builder
         * @throws IllegalArgumentException if {@code step < 1}
         */
        public Builder setStep(int step) {
            if (step < 1) {
                throw new IllegalArgumentException("step should be larger or equal to 1");
            }
            this.step = step;
            return this;
        }

        /**
         * Sets the multiplier applied to the learning rate at every decay (default 1).
         *
         * @param factor the decay multiplier, no more than 1
         * @return this builder
         * @throws IllegalArgumentException if {@code factor > 1}
         */
        public Builder optFactor(float factor) {
            if (factor > 1.0f) {
                throw new IllegalArgumentException("factor should be no more than 1");
            }
            this.factor = factor;
            return this;
        }

        /**
         * Sets the floor below which the decayed learning rate is clamped (default 1e-8).
         *
         * @param stopFactorLearningRate the minimum learning rate reached through decay
         * @return this builder
         */
        public Builder optStopFactorLearningRate(float stopFactorLearningRate) {
            this.stopFactorLearningRate = stopFactorLearningRate;
            return this;
        }

        /**
         * Builds the tracker.
         *
         * @return a new {@link FactorTracker}
         * @throws IllegalArgumentException if {@link #setStep(int)} was never called
         */
        public FactorTracker build() {
            if (this.step == 0) {
                throw new IllegalArgumentException("Step must be set to change learning rate every N steps");
            }
            return new FactorTracker(this);
        }
    }
}

