/*
 * Decompiled with CFR 0.152.
 */
package ai.konduit.serving.models.samediff.step.trainer;

import ai.konduit.serving.annotation.json.JsonName;
import ai.konduit.serving.pipeline.api.step.PipelineStep;
import io.swagger.v3.oas.annotations.media.Schema;
import java.util.List;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.schedule.ISchedule;
import org.nd4j.shade.jackson.annotation.JsonProperty;

@JsonName(value="SAMEDIFF_TRAINING")
@Schema(description="A pipeline step that configures a SameDiff model that is to be executed.")
public class SameDiffTrainerStep
implements PipelineStep {
    @Schema(description="Specifies the location of a saved model file.")
    private String modelUri;
    // -1.0 is the "not set" sentinel for l1/l2 (see the schema descriptions below).
    @Schema(description="An L1 regularization coefficient for application during training.  Set this value for l1 regularization. Not applied by default.")
    private double l1 = -1.0;
    @Schema(description="An L2 regularization coefficient for application during training. Set this value for l2 regularization. Not applied by default.")
    private double l2 = -1.0;
    @Schema(description="A weight regularization coefficient for application during training. Set this value to enable weight decay. Disabled by default.")
    private double weightDecayCoefficient;
    @Schema(description="Whether to apply learning rate during weight decay, defaults to true")
    private boolean weightDecayApplyLearningRate = true;
    @Schema(description="Specifies the location of the model save path")
    private String modelSaveOutputPath;
    @Schema(description="Specifies the number of epochs to run training for")
    private int numEpochs = 1;
    @Schema(description="A list of names of the loss variables- the names of the targets to train against for the loss function")
    private List<String> lossVariables;
    @Schema(description="A list of names of the input variables- the names of the input variables for training")
    private List<String> inputFeatures;
    @Schema(description="A list of names of the labels variables- the names of the true labels for prediction to calculate error against")
    private List<String> labels;
    @Schema(description="A list of names of the prediction variables- the names of the prediction labels for prediction to calculate error against")
    private List<String> targetVariables;
    @Schema(description="The updater to use for training. When specifying an updater on the command line, the type is needed. Valid types include:  AMSGRAD,ADABELIEF,ADAGRAD,ADADELTA,ADAMAX,ADAM,NADAM,NESTEROVS,NOOP,RMSPROP,SGD . Each field for the updater must be specified in terms of field name = value separated by commas. Relevant updaters and their fields can be found here: https://github.com/eclipse/deeplearning4j/tree/master/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/config")
    private IUpdater updater;
    @Schema(description="The learning rate to use for training")
    private double learningRate;
    @Schema(description="The learning rate schedule to use for training. When specifying a learning rate or momentum schedule, comma separated values with key=value for each field is required. Valid values include: poly,step,cycle,fixed,inverse,sigmoid,exponential. Relevant schedules and their fields can be found here: https://github.com/eclipse/deeplearning4j/tree/master/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/schedule - it is recommended when specifying this value on the command line to use \" to ensure the value gets parsed properly.")
    private ISchedule learningRateSchedule;
    @Schema(description="The initial loss type for the training, defaults to float")
    private DataType initialLossType = DataType.FLOAT;
    @Schema(description="The loss function to use for training models")
    private LossFunctions.LossFunction lossFunction;
    @Schema(description="Enable debug mode, defaults to false")
    private boolean debugMode = false;
    @Schema(description="Enable verbose mode, defaults to false")
    private boolean verboseMode = false;

    /**
     * Creates a fully specified step. Every parameter is bound to its JSON property via
     * {@code @JsonProperty}.
     *
     * <p>A {@code null} {@code initialLossType} falls back to {@link DataType#FLOAT}. The learning
     * rate and schedule are pushed into the updater only when all three are usable: {@code updater}
     * is non-null, {@code learningRate > 0}, and {@code learningRateSchedule} is non-null. A missing
     * updater is skipped rather than causing a NullPointerException.
     */
    public SameDiffTrainerStep(@JsonProperty(value="modelUri") String modelUri, @JsonProperty(value="l1") double l1, @JsonProperty(value="l2") double l2, @JsonProperty(value="modelSaveOutputPath") String modelSaveOutputPath, @JsonProperty(value="numEpochs") int numEpochs, @JsonProperty(value="inputFeatures") List<String> inputFeatures, @JsonProperty(value="lossVariables") List<String> lossVariables, @JsonProperty(value="labels") List<String> labels, @JsonProperty(value="targetVariables") List<String> targetVariables, @JsonProperty(value="weightDecayCoefficient") double weightDecayCoefficient, @JsonProperty(value="weightDecayApplyLearningRate") boolean weightDecayApplyLearningRate, @JsonProperty(value="updater") IUpdater updater, @JsonProperty(value="learningRate") double learningRate, @JsonProperty(value="learningRateSchedule") ISchedule learningRateSchedule, @JsonProperty(value="initialLossType") DataType initialLossType, @JsonProperty(value="lossFunction") LossFunctions.LossFunction lossFunction, @JsonProperty(value="debugMode") boolean debugMode, @JsonProperty(value="verboseMode") boolean verboseMode) {
        this.modelUri = modelUri;
        this.l1 = l1;
        this.l2 = l2;
        this.modelSaveOutputPath = modelSaveOutputPath;
        this.numEpochs = numEpochs;
        this.lossVariables = lossVariables;
        this.inputFeatures = inputFeatures;
        this.targetVariables = targetVariables;
        this.labels = labels;
        this.weightDecayApplyLearningRate = weightDecayApplyLearningRate;
        this.weightDecayCoefficient = weightDecayCoefficient;
        this.learningRate = learningRate;
        this.learningRateSchedule = learningRateSchedule;
        this.updater = updater;
        this.lossFunction = lossFunction;
        if (initialLossType != null) {
            this.initialLossType = initialLossType;
        }
        this.debugMode = debugMode;
        this.verboseMode = verboseMode;
        applyLearningRateToUpdater();
    }

    /**
     * Creates a step from a builder. Applies the same normalization as the JSON constructor:
     * a {@code null} {@code initialLossType} becomes {@link DataType#FLOAT}, and the learning rate
     * and schedule are pushed into the updater under the same conditions.
     */
    protected SameDiffTrainerStep(SameDiffTrainerStepBuilder<?, ?> b) {
        this.modelUri = b.modelUri;
        this.l1 = b.l1;
        this.l2 = b.l2;
        this.weightDecayCoefficient = b.weightDecayCoefficient;
        this.weightDecayApplyLearningRate = b.weightDecayApplyLearningRate;
        this.modelSaveOutputPath = b.modelSaveOutputPath;
        this.numEpochs = b.numEpochs;
        this.lossVariables = b.lossVariables;
        this.inputFeatures = b.inputFeatures;
        this.labels = b.labels;
        this.targetVariables = b.targetVariables;
        this.updater = b.updater;
        this.learningRate = b.learningRate;
        this.learningRateSchedule = b.learningRateSchedule;
        this.initialLossType = b.initialLossType != null ? b.initialLossType : DataType.FLOAT;
        this.lossFunction = b.lossFunction;
        this.debugMode = b.debugMode;
        this.verboseMode = b.verboseMode;
        applyLearningRateToUpdater();
    }

    /**
     * Pushes {@link #learningRate} and {@link #learningRateSchedule} into {@link #updater}.
     *
     * <p>This happens only when the updater is present, the rate is positive and a schedule is
     * supplied. Note that this mutates the caller-supplied updater instance.
     */
    private void applyLearningRateToUpdater() {
        if (this.updater != null && this.learningRate > 0.0 && this.learningRateSchedule != null) {
            this.updater.setLrAndSchedule(this.learningRate, this.learningRateSchedule);
        }
    }

    /** Returns a new builder whose defaults match this class's field defaults. */
    public static SameDiffTrainerStepBuilder<?, ?> builder() {
        return new SameDiffTrainerStepBuilderImpl();
    }

    // ---- Fluent accessors -------------------------------------------------------------------

    public String modelUri() {
        return this.modelUri;
    }

    public double l1() {
        return this.l1;
    }

    public double l2() {
        return this.l2;
    }

    public double weightDecayCoefficient() {
        return this.weightDecayCoefficient;
    }

    public boolean weightDecayApplyLearningRate() {
        return this.weightDecayApplyLearningRate;
    }

    public String modelSaveOutputPath() {
        return this.modelSaveOutputPath;
    }

    public int numEpochs() {
        return this.numEpochs;
    }

    public List<String> lossVariables() {
        return this.lossVariables;
    }

    public List<String> inputFeatures() {
        return this.inputFeatures;
    }

    public List<String> labels() {
        return this.labels;
    }

    public List<String> targetVariables() {
        return this.targetVariables;
    }

    public IUpdater updater() {
        return this.updater;
    }

    public double learningRate() {
        return this.learningRate;
    }

    public ISchedule learningRateSchedule() {
        return this.learningRateSchedule;
    }

    public DataType initialLossType() {
        return this.initialLossType;
    }

    public LossFunctions.LossFunction lossFunction() {
        return this.lossFunction;
    }

    public boolean debugMode() {
        return this.debugMode;
    }

    public boolean verboseMode() {
        return this.verboseMode;
    }

    // ---- Fluent mutators (each returns this) ------------------------------------------------

    public SameDiffTrainerStep modelUri(String modelUri) {
        this.modelUri = modelUri;
        return this;
    }

    public SameDiffTrainerStep l1(double l1) {
        this.l1 = l1;
        return this;
    }

    public SameDiffTrainerStep l2(double l2) {
        this.l2 = l2;
        return this;
    }

    public SameDiffTrainerStep weightDecayCoefficient(double weightDecayCoefficient) {
        this.weightDecayCoefficient = weightDecayCoefficient;
        return this;
    }

    public SameDiffTrainerStep weightDecayApplyLearningRate(boolean weightDecayApplyLearningRate) {
        this.weightDecayApplyLearningRate = weightDecayApplyLearningRate;
        return this;
    }

    public SameDiffTrainerStep modelSaveOutputPath(String modelSaveOutputPath) {
        this.modelSaveOutputPath = modelSaveOutputPath;
        return this;
    }

    public SameDiffTrainerStep numEpochs(int numEpochs) {
        this.numEpochs = numEpochs;
        return this;
    }

    public SameDiffTrainerStep lossVariables(List<String> lossVariables) {
        this.lossVariables = lossVariables;
        return this;
    }

    public SameDiffTrainerStep inputFeatures(List<String> inputFeatures) {
        this.inputFeatures = inputFeatures;
        return this;
    }

    public SameDiffTrainerStep labels(List<String> labels) {
        this.labels = labels;
        return this;
    }

    public SameDiffTrainerStep targetVariables(List<String> targetVariables) {
        this.targetVariables = targetVariables;
        return this;
    }

    public SameDiffTrainerStep updater(IUpdater updater) {
        this.updater = updater;
        return this;
    }

    public SameDiffTrainerStep learningRate(double learningRate) {
        this.learningRate = learningRate;
        return this;
    }

    public SameDiffTrainerStep learningRateSchedule(ISchedule learningRateSchedule) {
        this.learningRateSchedule = learningRateSchedule;
        return this;
    }

    public SameDiffTrainerStep initialLossType(DataType initialLossType) {
        this.initialLossType = initialLossType;
        return this;
    }

    public SameDiffTrainerStep lossFunction(LossFunctions.LossFunction lossFunction) {
        this.lossFunction = lossFunction;
        return this;
    }

    public SameDiffTrainerStep debugMode(boolean debugMode) {
        this.debugMode = debugMode;
        return this;
    }

    public SameDiffTrainerStep verboseMode(boolean verboseMode) {
        this.verboseMode = verboseMode;
        return this;
    }

    /** Field-by-field equality over every configuration field; subclasses participate via {@link #canEqual}. */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof SameDiffTrainerStep)) {
            return false;
        }
        SameDiffTrainerStep other = (SameDiffTrainerStep)o;
        if (!other.canEqual(this)) {
            return false;
        }
        if (Double.compare(this.l1(), other.l1()) != 0) {
            return false;
        }
        if (Double.compare(this.l2(), other.l2()) != 0) {
            return false;
        }
        if (Double.compare(this.weightDecayCoefficient(), other.weightDecayCoefficient()) != 0) {
            return false;
        }
        if (this.weightDecayApplyLearningRate() != other.weightDecayApplyLearningRate()) {
            return false;
        }
        if (this.numEpochs() != other.numEpochs()) {
            return false;
        }
        if (Double.compare(this.learningRate(), other.learningRate()) != 0) {
            return false;
        }
        if (this.debugMode() != other.debugMode()) {
            return false;
        }
        if (this.verboseMode() != other.verboseMode()) {
            return false;
        }
        String this$modelUri = this.modelUri();
        String other$modelUri = other.modelUri();
        if (this$modelUri == null ? other$modelUri != null : !this$modelUri.equals(other$modelUri)) {
            return false;
        }
        String this$modelSaveOutputPath = this.modelSaveOutputPath();
        String other$modelSaveOutputPath = other.modelSaveOutputPath();
        if (this$modelSaveOutputPath == null ? other$modelSaveOutputPath != null : !this$modelSaveOutputPath.equals(other$modelSaveOutputPath)) {
            return false;
        }
        List<String> this$lossVariables = this.lossVariables();
        List<String> other$lossVariables = other.lossVariables();
        if (this$lossVariables == null ? other$lossVariables != null : !this$lossVariables.equals(other$lossVariables)) {
            return false;
        }
        List<String> this$inputFeatures = this.inputFeatures();
        List<String> other$inputFeatures = other.inputFeatures();
        if (this$inputFeatures == null ? other$inputFeatures != null : !this$inputFeatures.equals(other$inputFeatures)) {
            return false;
        }
        List<String> this$labels = this.labels();
        List<String> other$labels = other.labels();
        if (this$labels == null ? other$labels != null : !this$labels.equals(other$labels)) {
            return false;
        }
        List<String> this$targetVariables = this.targetVariables();
        List<String> other$targetVariables = other.targetVariables();
        if (this$targetVariables == null ? other$targetVariables != null : !this$targetVariables.equals(other$targetVariables)) {
            return false;
        }
        IUpdater this$updater = this.updater();
        IUpdater other$updater = other.updater();
        if (this$updater == null ? other$updater != null : !this$updater.equals(other$updater)) {
            return false;
        }
        ISchedule this$learningRateSchedule = this.learningRateSchedule();
        ISchedule other$learningRateSchedule = other.learningRateSchedule();
        if (this$learningRateSchedule == null ? other$learningRateSchedule != null : !this$learningRateSchedule.equals(other$learningRateSchedule)) {
            return false;
        }
        DataType this$initialLossType = this.initialLossType();
        DataType other$initialLossType = other.initialLossType();
        if (this$initialLossType == null ? other$initialLossType != null : !this$initialLossType.equals(other$initialLossType)) {
            return false;
        }
        LossFunctions.LossFunction this$lossFunction = this.lossFunction();
        LossFunctions.LossFunction other$lossFunction = other.lossFunction();
        return !(this$lossFunction == null ? other$lossFunction != null : !this$lossFunction.equals(other$lossFunction));
    }

    /** Lombok-style symmetry hook so a subclass can opt out of equality with this type. */
    protected boolean canEqual(Object other) {
        return other instanceof SameDiffTrainerStep;
    }

    /** Hash over the same fields used by {@link #equals(Object)}; null references hash as 43. */
    @Override
    public int hashCode() {
        final int PRIME = 59;
        int result = 1;
        long $l1 = Double.doubleToLongBits(this.l1());
        result = result * PRIME + (int)($l1 >>> 32 ^ $l1);
        long $l2 = Double.doubleToLongBits(this.l2());
        result = result * PRIME + (int)($l2 >>> 32 ^ $l2);
        long $weightDecayCoefficient = Double.doubleToLongBits(this.weightDecayCoefficient());
        result = result * PRIME + (int)($weightDecayCoefficient >>> 32 ^ $weightDecayCoefficient);
        result = result * PRIME + (this.weightDecayApplyLearningRate() ? 79 : 97);
        result = result * PRIME + this.numEpochs();
        long $learningRate = Double.doubleToLongBits(this.learningRate());
        result = result * PRIME + (int)($learningRate >>> 32 ^ $learningRate);
        result = result * PRIME + (this.debugMode() ? 79 : 97);
        result = result * PRIME + (this.verboseMode() ? 79 : 97);
        String $modelUri = this.modelUri();
        result = result * PRIME + ($modelUri == null ? 43 : $modelUri.hashCode());
        String $modelSaveOutputPath = this.modelSaveOutputPath();
        result = result * PRIME + ($modelSaveOutputPath == null ? 43 : $modelSaveOutputPath.hashCode());
        List<String> $lossVariables = this.lossVariables();
        result = result * PRIME + ($lossVariables == null ? 43 : $lossVariables.hashCode());
        List<String> $inputFeatures = this.inputFeatures();
        result = result * PRIME + ($inputFeatures == null ? 43 : $inputFeatures.hashCode());
        List<String> $labels = this.labels();
        result = result * PRIME + ($labels == null ? 43 : $labels.hashCode());
        List<String> $targetVariables = this.targetVariables();
        result = result * PRIME + ($targetVariables == null ? 43 : $targetVariables.hashCode());
        IUpdater $updater = this.updater();
        result = result * PRIME + ($updater == null ? 43 : $updater.hashCode());
        ISchedule $learningRateSchedule = this.learningRateSchedule();
        result = result * PRIME + ($learningRateSchedule == null ? 43 : $learningRateSchedule.hashCode());
        DataType $initialLossType = this.initialLossType();
        result = result * PRIME + ($initialLossType == null ? 43 : $initialLossType.hashCode());
        LossFunctions.LossFunction $lossFunction = this.lossFunction();
        result = result * PRIME + ($lossFunction == null ? 43 : $lossFunction.hashCode());
        return result;
    }

    @Override
    public String toString() {
        return "SameDiffTrainerStep(modelUri=" + this.modelUri() + ", l1=" + this.l1() + ", l2=" + this.l2() + ", weightDecayCoefficient=" + this.weightDecayCoefficient() + ", weightDecayApplyLearningRate=" + this.weightDecayApplyLearningRate() + ", modelSaveOutputPath=" + this.modelSaveOutputPath() + ", numEpochs=" + this.numEpochs() + ", lossVariables=" + this.lossVariables() + ", inputFeatures=" + this.inputFeatures() + ", labels=" + this.labels() + ", targetVariables=" + this.targetVariables() + ", updater=" + this.updater() + ", learningRate=" + this.learningRate() + ", learningRateSchedule=" + this.learningRateSchedule() + ", initialLossType=" + this.initialLossType() + ", lossFunction=" + this.lossFunction() + ", debugMode=" + this.debugMode() + ", verboseMode=" + this.verboseMode() + ")";
    }

    /** No-arg constructor; every field keeps its declared default. */
    public SameDiffTrainerStep() {
    }

    /** Concrete builder returned by {@link #builder()}. */
    private static final class SameDiffTrainerStepBuilderImpl
    extends SameDiffTrainerStepBuilder<SameDiffTrainerStep, SameDiffTrainerStepBuilderImpl> {
        private SameDiffTrainerStepBuilderImpl() {
        }

        @Override
        protected SameDiffTrainerStepBuilderImpl self() {
            return this;
        }

        @Override
        public SameDiffTrainerStep build() {
            return new SameDiffTrainerStep(this);
        }
    }

    /**
     * Self-typed builder for {@link SameDiffTrainerStep} and its subclasses.
     *
     * <p>Field initializers mirror the defaults declared on {@link SameDiffTrainerStep}. Without
     * them, a builder-built step would silently get {@code numEpochs = 0},
     * {@code weightDecayApplyLearningRate = false}, {@code l1}/{@code l2 = 0.0} instead of the
     * {@code -1.0} "not applied" sentinel, and a {@code null} {@code initialLossType}.
     *
     * @param <C> the step type produced by {@link #build()}
     * @param <B> the concrete builder type returned by each setter
     */
    public static abstract class SameDiffTrainerStepBuilder<C extends SameDiffTrainerStep, B extends SameDiffTrainerStepBuilder<C, B>> {
        private String modelUri;
        private double l1 = -1.0;
        private double l2 = -1.0;
        private double weightDecayCoefficient;
        private boolean weightDecayApplyLearningRate = true;
        private String modelSaveOutputPath;
        private int numEpochs = 1;
        private List<String> lossVariables;
        private List<String> inputFeatures;
        private List<String> labels;
        private List<String> targetVariables;
        private IUpdater updater;
        private double learningRate;
        private ISchedule learningRateSchedule;
        private DataType initialLossType = DataType.FLOAT;
        private LossFunctions.LossFunction lossFunction;
        private boolean debugMode;
        private boolean verboseMode;

        /** Returns this builder typed as the concrete subclass, for fluent chaining. */
        protected abstract B self();

        /** Creates the step from the values configured so far. */
        public abstract C build();

        public B modelUri(String modelUri) {
            this.modelUri = modelUri;
            return this.self();
        }

        public B l1(double l1) {
            this.l1 = l1;
            return this.self();
        }

        public B l2(double l2) {
            this.l2 = l2;
            return this.self();
        }

        public B weightDecayCoefficient(double weightDecayCoefficient) {
            this.weightDecayCoefficient = weightDecayCoefficient;
            return this.self();
        }

        public B weightDecayApplyLearningRate(boolean weightDecayApplyLearningRate) {
            this.weightDecayApplyLearningRate = weightDecayApplyLearningRate;
            return this.self();
        }

        public B modelSaveOutputPath(String modelSaveOutputPath) {
            this.modelSaveOutputPath = modelSaveOutputPath;
            return this.self();
        }

        public B numEpochs(int numEpochs) {
            this.numEpochs = numEpochs;
            return this.self();
        }

        public B lossVariables(List<String> lossVariables) {
            this.lossVariables = lossVariables;
            return this.self();
        }

        public B inputFeatures(List<String> inputFeatures) {
            this.inputFeatures = inputFeatures;
            return this.self();
        }

        public B labels(List<String> labels) {
            this.labels = labels;
            return this.self();
        }

        public B targetVariables(List<String> targetVariables) {
            this.targetVariables = targetVariables;
            return this.self();
        }

        public B updater(IUpdater updater) {
            this.updater = updater;
            return this.self();
        }

        public B learningRate(double learningRate) {
            this.learningRate = learningRate;
            return this.self();
        }

        public B learningRateSchedule(ISchedule learningRateSchedule) {
            this.learningRateSchedule = learningRateSchedule;
            return this.self();
        }

        public B initialLossType(DataType initialLossType) {
            this.initialLossType = initialLossType;
            return this.self();
        }

        public B lossFunction(LossFunctions.LossFunction lossFunction) {
            this.lossFunction = lossFunction;
            return this.self();
        }

        public B debugMode(boolean debugMode) {
            this.debugMode = debugMode;
            return this.self();
        }

        public B verboseMode(boolean verboseMode) {
            this.verboseMode = verboseMode;
            return this.self();
        }

        @Override
        public String toString() {
            return "SameDiffTrainerStep.SameDiffTrainerStepBuilder(modelUri=" + this.modelUri + ", l1=" + this.l1 + ", l2=" + this.l2 + ", weightDecayCoefficient=" + this.weightDecayCoefficient + ", weightDecayApplyLearningRate=" + this.weightDecayApplyLearningRate + ", modelSaveOutputPath=" + this.modelSaveOutputPath + ", numEpochs=" + this.numEpochs + ", lossVariables=" + this.lossVariables + ", inputFeatures=" + this.inputFeatures + ", labels=" + this.labels + ", targetVariables=" + this.targetVariables + ", updater=" + this.updater + ", learningRate=" + this.learningRate + ", learningRateSchedule=" + this.learningRateSchedule + ", initialLossType=" + this.initialLossType + ", lossFunction=" + this.lossFunction + ", debugMode=" + this.debugMode + ", verboseMode=" + this.verboseMode + ")";
        }
    }
}

