/*
 * Decompiled with CFR 0.152.
 */
package org.pointyware.disco.training.entities.optimizers;

import androidx.compose.runtime.internal.StabilityInferred;
import java.util.Iterator;
import java.util.List;
import kotlin.Metadata;
import kotlin.collections.CollectionsKt;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import kotlin.jvm.internal.Reflection;
import kotlin.jvm.internal.SourceDebugExtension;
import kotlin.random.Random;
import org.jetbrains.annotations.NotNull;
import org.pointyware.disco.entities.layers.DenseLayer;
import org.pointyware.disco.entities.layers.Layer;
import org.pointyware.disco.entities.tensors.Tensor;
import org.pointyware.disco.training.entities.Exercise;
import org.pointyware.disco.training.entities.optimizers.LearningRateSchedule;
import org.pointyware.disco.training.entities.optimizers.SinglePassOptimizer;

@Metadata(mv={2, 2, 0}, k=1, xi=48, d1={"\u0000>\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0000\n\u0002\u0018\u0002\n\u0000\n\u0002\u0018\u0002\n\u0002\b\u0007\n\u0002\u0010 \n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0010\u0002\n\u0000\n\u0002\u0010\b\n\u0000\n\u0002\u0018\u0002\n\u0000\n\u0002\u0018\u0002\n\u0002\b\u0002\b\u0017\u0018\u00002\u00020\u0001B\u0019\u0012\u0006\u0010\u0002\u001a\u00020\u0003\u0012\b\b\u0002\u0010\u0004\u001a\u00020\u0005\u00a2\u0006\u0004\b\u0006\u0010\u0007J\"\u0010\f\u001a\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u000e0\r0\r2\f\u0010\u000f\u001a\b\u0012\u0004\u0012\u00020\u000e0\rH\u0016J(\u0010\u0010\u001a\u00020\u00112\u0006\u0010\u0012\u001a\u00020\u00132\u0006\u0010\u0014\u001a\u00020\u00152\u0006\u0010\u0016\u001a\u00020\u00172\u0006\u0010\u0018\u001a\u00020\u0017H\u0016R\u0011\u0010\u0002\u001a\u00020\u0003\u00a2\u0006\b\n\u0000\u001a\u0004\b\b\u0010\tR\u0011\u0010\u0004\u001a\u00020\u0005\u00a2\u0006\b\n\u0000\u001a\u0004\b\n\u0010\u000b\u00a8\u0006\u0019"}, d2={"Lorg/pointyware/disco/training/entities/optimizers/GradientDescent;", "Lorg/pointyware/disco/training/entities/optimizers/SinglePassOptimizer;", "learningRate", "Lorg/pointyware/disco/training/entities/optimizers/LearningRateSchedule;", "entropy", "Lkotlin/random/Random;", "<init>", "(Lorg/pointyware/disco/training/entities/optimizers/LearningRateSchedule;Lkotlin/random/Random;)V", "getLearningRate", "()Lorg/pointyware/disco/training/entities/optimizers/LearningRateSchedule;", "getEntropy", "()Lkotlin/random/Random;", "batch", "", "Lorg/pointyware/disco/training/entities/Exercise;", "cases", "update", "", "epoch", "", "layer", "Lorg/pointyware/disco/entities/layers/Layer;", "weightGradients", "Lorg/pointyware/disco/entities/tensors/Tensor;", "biasGradients", "feature-training"})
@StabilityInferred(parameters=0)
@SourceDebugExtension(value={"SMAP\nGradientDescent.kt\nKotlin\n*S Kotlin\n*F\n+ 1 GradientDescent.kt\norg/pointyware/disco/training/entities/optimizers/GradientDescent\n+ 2 Tensor.kt\norg/pointyware/disco/entities/tensors/Tensor\n*L\n1#1,44:1\n195#2,4:45\n195#2,4:49\n*S KotlinDebug\n*F\n+ 1 GradientDescent.kt\norg/pointyware/disco/training/entities/optimizers/GradientDescent\n*L\n30#1:45,4\n33#1:49,4\n*E\n"})
/*
 * Plain (full-batch) gradient descent optimizer.
 *
 * batch() hands back all training cases as a single batch. update() moves every weight and bias
 * of a DenseLayer against its gradient, scaled by the learning rate the schedule reports for the
 * current epoch:  p = p - rate(epoch) * grad.  Any other layer type is rejected.
 */
public class GradientDescent
implements SinglePassOptimizer {
    @NotNull
    private final LearningRateSchedule learningRate;
    // Not read anywhere in this class; exposed via getEntropy(), presumably for subclasses or
    // callers that need a random source — TODO confirm.
    @NotNull
    private final Random entropy;
    // Generated alongside @StabilityInferred by the Compose compiler; do not edit by hand.
    public static final int $stable = 8;

    /**
     * Creates an optimizer.
     *
     * @param learningRate schedule queried once per {@link #update} call for the step size
     * @param entropy random source exposed through {@link #getEntropy()}
     */
    public GradientDescent(@NotNull LearningRateSchedule learningRate, @NotNull Random entropy) {
        Intrinsics.checkNotNullParameter((Object)learningRate, (String)"learningRate");
        Intrinsics.checkNotNullParameter((Object)entropy, (String)"entropy");
        this.learningRate = learningRate;
        this.entropy = entropy;
    }

    /**
     * Synthetic bridge for Kotlin default arguments: when bit {@code 2} of {@code n} is set, the
     * {@code entropy} argument is replaced by {@code Random.Default}.
     */
    public /* synthetic */ GradientDescent(LearningRateSchedule learningRateSchedule, Random random, int n, DefaultConstructorMarker defaultConstructorMarker) {
        // An explicit this(...) call must be the first statement, so the default is picked inline.
        this(learningRateSchedule, (n & 2) != 0 ? (Random)Random.Default : random);
    }

    /** @return the learning-rate schedule this optimizer was built with */
    @NotNull
    public final LearningRateSchedule getLearningRate() {
        return this.learningRate;
    }

    /** @return the random source this optimizer was built with */
    @NotNull
    public final Random getEntropy() {
        return this.entropy;
    }

    /**
     * Groups training cases into batches. Gradient descent is full-batch, so this returns a
     * one-element list whose only batch is {@code cases} itself.
     *
     * @param cases all training cases
     * @return {@code [cases]}
     */
    @Override
    @NotNull
    public List<List<Exercise>> batch(@NotNull List<Exercise> cases) {
        Intrinsics.checkNotNullParameter(cases, (String)"cases");
        return CollectionsKt.listOf(cases);
    }

    /**
     * Applies one gradient-descent step, in place, to the weights and biases of {@code layer}.
     *
     * @param epoch epoch passed to the learning-rate schedule
     * @param layer layer whose parameters are updated; must be a {@link DenseLayer}
     * @param weightGradients gradients indexed by the same flat indices as the layer's weights
     *     (assumed to have matching size — TODO confirm callers guarantee this)
     * @param biasGradients gradients indexed by the same flat indices as the layer's biases
     * @throws IllegalArgumentException if {@code layer} is not a {@link DenseLayer}
     */
    @Override
    public void update(int epoch, @NotNull Layer layer, @NotNull Tensor weightGradients, @NotNull Tensor biasGradients) {
        Intrinsics.checkNotNullParameter((Object)layer, (String)"layer");
        Intrinsics.checkNotNullParameter((Object)weightGradients, (String)"weightGradients");
        Intrinsics.checkNotNullParameter((Object)biasGradients, (String)"biasGradients");
        // The schedule is consulted before the type check, matching the original call order.
        float currentLearningRate = this.learningRate.learningRate(epoch);
        if (!(layer instanceof DenseLayer)) {
            throw new IllegalArgumentException("Unsupported layer type: " + Reflection.getOrCreateKotlinClass(layer.getClass()).getSimpleName());
        }
        DenseLayer denseLayer = (DenseLayer)layer;
        GradientDescent.descend(denseLayer.getWeights(), weightGradients, currentLearningRate);
        GradientDescent.descend(denseLayer.getBiases(), biasGradients, currentLearningRate);
    }

    // In-place step over every flat index of parameters: parameters[i] -= rate * gradients[i].
    private static void descend(Tensor parameters, Tensor gradients, float rate) {
        Iterator<?> indices = parameters.getFlatIndices();
        while (indices.hasNext()) {
            int index = ((Number)indices.next()).intValue();
            float value = parameters.get(index);
            parameters.set(index, value - rate * gradients.get(index));
        }
    }
}

