/*
 * Decompiled with CFR 0.152.
 */
package org.pointyware.disco.training.entities;

import androidx.compose.runtime.internal.StabilityInferred;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import kotlin.Metadata;
import kotlin.Unit;
import kotlin.collections.ArraysKt;
import kotlin.collections.CollectionsKt;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import kotlin.jvm.internal.Reflection;
import kotlin.jvm.internal.SourceDebugExtension;
import kotlin.reflect.KClass;
import kotlinx.coroutines.flow.FlowKt;
import kotlinx.coroutines.flow.MutableStateFlow;
import kotlinx.coroutines.flow.StateFlow;
import kotlinx.coroutines.flow.StateFlowKt;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import org.pointyware.disco.entities.layers.DenseLayer;
import org.pointyware.disco.entities.layers.Layer;
import org.pointyware.disco.entities.loss.LossFunction;
import org.pointyware.disco.entities.networks.SequentialNetwork;
import org.pointyware.disco.entities.tensors.Tensor;
import org.pointyware.disco.entities.tensors.TensorPool;
import org.pointyware.disco.training.entities.BatchStatistics;
import org.pointyware.disco.training.entities.ComputationContext;
import org.pointyware.disco.training.entities.ComputationKey;
import org.pointyware.disco.training.entities.EpochStatistics;
import org.pointyware.disco.training.entities.Exercise;
import org.pointyware.disco.training.entities.LayerStatistics;
import org.pointyware.disco.training.entities.SampleStatistics;
import org.pointyware.disco.training.entities.Snapshot;
import org.pointyware.disco.training.entities.Statistics;
import org.pointyware.disco.training.entities.Trainer;
import org.pointyware.disco.training.entities.optimizers.MultiPassOptimizer;
import org.pointyware.disco.training.entities.optimizers.Optimizer;
import org.pointyware.disco.training.entities.optimizers.SinglePassOptimizer;
import org.pointyware.disco.training.entities.optimizers.StatisticalOptimizer;
import org.pointyware.disco.training.interactors.TrainingControllerKt;

@Metadata(mv={2, 2, 0}, k=1, xi=48, d1={"\u0000\u0086\u0001\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0000\n\u0002\u0018\u0002\n\u0000\n\u0002\u0010 \n\u0002\u0018\u0002\n\u0000\n\u0002\u0018\u0002\n\u0000\n\u0002\u0018\u0002\n\u0000\n\u0002\u0018\u0002\n\u0000\n\u0002\u0010\u0006\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\f\n\u0002\u0010\u000b\n\u0000\n\u0002\u0018\u0002\n\u0000\n\u0002\u0018\u0002\n\u0000\n\u0002\u0018\u0002\n\u0000\n\u0002\u0018\u0002\n\u0000\n\u0002\u0018\u0002\n\u0000\n\u0002\u0018\u0002\n\u0000\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0000\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0018\u0002\n\u0000\n\u0002\u0010\b\n\u0002\b\u0003\b\u0007\u0018\u00002\u00020\u0001B?\u0012\u0006\u0010\u0002\u001a\u00020\u0003\u0012\f\u0010\u0004\u001a\b\u0012\u0004\u0012\u00020\u00060\u0005\u0012\u0006\u0010\u0007\u001a\u00020\b\u0012\u0006\u0010\t\u001a\u00020\n\u0012\u0006\u0010\u000b\u001a\u00020\f\u0012\b\b\u0002\u0010\r\u001a\u00020\u000e\u00a2\u0006\u0004\b\u000f\u0010\u0010B/\b\u0016\u0012\u0006\u0010\u0002\u001a\u00020\u0003\u0012\f\u0010\u0004\u001a\b\u0012\u0004\u0012\u00020\u00060\u0005\u0012\u0006\u0010\u0007\u001a\u00020\b\u0012\u0006\u0010\t\u001a\u00020\u0011\u00a2\u0006\u0004\b\u000f\u0010\u0012J\u0010\u00106\u001a\u0002052\u0006\u00107\u001a\u000205H\u0016R\u0011\u0010\u0002\u001a\u00020\u0003\u00a2\u0006\b\n\u0000\u001a\u0004\b\u0013\u0010\u0014R\u0017\u0010\u0004\u001a\b\u0012\u0004\u0012\u00020\u00060\u0005\u00a2\u0006\b\n\u0000\u001a\u0004\b\u0015\u0010\u0016R\u0011\u0010\u0007\u001a\u00020\b\u00a2\u0006\b\n\u0000\u001a\u0004\b\u0017\u0010\u0018R\u0011\u0010\t\u001a\u00020\n\u00a2\u0006\b\n\u0000\u001a\u0004\b\u0019\u0010\u001aR\u0011\u0010\u000b\u001a\u00020\f\u00a2\u0006\b\n\u0000\u001a\u0004\b\u001b\u0010\u001cR\u000e\u0010\r\u001a\u00020\u000eX\u0082\u0004\u00a2\u0006\u0002\n\u0000R\u000e\u0010\u001d\u001a\u00020\u001eX\u0082\u000e\u00a2\u0006\u0002\n\u0000R\u0010\u0010\u001f\u001a\u0004\u0018\u00010 X\u0082\u000e\u00a2\u0006\u0002\n\u0000R\u0010\u0010!\u001a\u0004\u0018\u00010\"X\u0082\u000e\u00a2\u0006\u0002\n\u0000R\u0010\u0010#\u001a\u0004\u0018\u00010$X\u0082\u000e\u00a2\u0006\u0002\n\u0000R\u0010\u0010%\u001a\u0004\u0018\u00010&X\u0082\u000e\u00a2\u0006\u0002\n\u0000R\u0010\u0010'\u001a\u0004\u0018\u00010(X\u0082\u0004\u00a2\u0006\u0002\n\u0000R\u0010\u0010)\u001a\u0004\u0018\u00010*X\u0082\u0004\u00a2\u0006\u0002\n\u0000R\u0014\u0010+\u001a\b\u0012\u0004\u0012\u00020-0,X\u0082\u0004\u00a2\u0006\u0002\n\u0000R\u001a\u0010.\u001a\b\u0012\u0004\u0012\u00020-0/8VX\u0096\u0004\u00a2\u0006\u0006\u001a\u0004\b0\u00101R\u000e\u00102\u001a\u000203X\u0082\u000e\u00a2\u0006\u0002\n\u0000R\u000e\u00104\u001a\u000205X\u0082\u000e\u00a2\u0006\u0002\n\u0000\u00a8\u00068"}, d2={"Lorg/pointyware/disco/training/entities/SequentialTrainer;", "Lorg/pointyware/disco/training/entities/Trainer;", "network", "Lorg/pointyware/disco/entities/networks/SequentialNetwork;", "cases", "", "Lorg/pointyware/disco/training/entities/Exercise;", "lossFunction", "Lorg/pointyware/disco/entities/loss/LossFunction;", "optimizer", "Lorg/pointyware/disco/training/entities/optimizers/Optimizer;", "statistics", "Lorg/pointyware/disco/training/entities/Statistics;", "acceptableError", "", "<init>", "(Lorg/pointyware/disco/entities/networks/SequentialNetwork;Ljava/util/List;Lorg/pointyware/disco/entities/loss/LossFunction;Lorg/pointyware/disco/training/entities/optimizers/Optimizer;Lorg/pointyware/disco/training/entities/Statistics;D)V", "Lorg/pointyware/disco/training/entities/optimizers/StatisticalOptimizer;", "(Lorg/pointyware/disco/entities/networks/SequentialNetwork;Ljava/util/List;Lorg/pointyware/disco/entities/loss/LossFunction;Lorg/pointyware/disco/training/entities/optimizers/StatisticalOptimizer;)V", "getNetwork", "()Lorg/pointyware/disco/entities/networks/SequentialNetwork;", "getCases", "()Ljava/util/List;", "getLossFunction", "()Lorg/pointyware/disco/entities/loss/LossFunction;", "getOptimizer", "()Lorg/pointyware/disco/training/entities/optimizers/Optimizer;", "getStatistics", "()Lorg/pointyware/disco/training/entities/Statistics;", "done", "", "epochStatistics", "Lorg/pointyware/disco/training/entities/EpochStatistics;", "batchStatistics", "Lorg/pointyware/disco/training/entities/BatchStatistics;", "sampleStatistics", "Lorg/pointyware/disco/training/entities/SampleStatistics;", "layerStatistics", "Lorg/pointyware/disco/training/entities/LayerStatistics;", "singlePassOptimizer", "Lorg/pointyware/disco/training/entities/optimizers/SinglePassOptimizer;", "multiPassOptimizer", "Lorg/pointyware/disco/training/entities/optimizers/MultiPassOptimizer;", "_snapshot", "Lkotlinx/coroutines/flow/MutableStateFlow;", "Lorg/pointyware/disco/training/entities/Snapshot;", "snapshot", "Lkotlinx/coroutines/flow/StateFlow;", "getSnapshot", "()Lkotlinx/coroutines/flow/StateFlow;", "tensorPool", "Lorg/pointyware/disco/entities/tensors/TensorPool;", "epoch", "", "train", "iterations", "feature-training"})
@StabilityInferred(parameters=0)
@SourceDebugExtension(value={"SMAP\nSequentialTrainer.kt\nKotlin\n*S Kotlin\n*F\n+ 1 SequentialTrainer.kt\norg/pointyware/disco/training/entities/SequentialTrainer\n+ 2 ComputationKey.kt\norg/pointyware/disco/training/entities/ComputationKeyKt\n+ 3 _Collections.kt\nkotlin/collections/CollectionsKt___CollectionsKt\n+ 4 Tensor.kt\norg/pointyware/disco/entities/tensors/Tensor\n*L\n1#1,209:1\n14#2,2:210\n14#2,2:212\n14#2,2:214\n14#2,2:216\n1563#3:218\n1634#3,3:219\n1563#3:222\n1634#3,3:223\n1563#3:226\n1634#3,3:227\n1563#3:230\n1634#3,3:231\n1869#3,2:234\n1869#3,2:236\n1869#3:238\n1869#3,2:239\n1869#3,2:241\n1870#3:243\n1869#3:244\n1870#3:249\n1869#3:250\n1870#3:255\n1878#3,3:256\n1878#3,3:259\n1878#3,3:262\n1869#3,2:265\n1869#3,2:267\n1869#3,2:269\n1869#3,2:271\n179#4,4:245\n179#4,4:251\n*S KotlinDebug\n*F\n+ 1 SequentialTrainer.kt\norg/pointyware/disco/training/entities/SequentialTrainer\n*L\n73#1:210,2\n74#1:212,2\n75#1:214,2\n76#1:216,2\n79#1:218\n79#1:219,3\n81#1:222\n81#1:223,3\n85#1:226\n85#1:227,3\n87#1:230\n87#1:231,3\n104#1:234,2\n105#1:236,2\n112#1:238\n114#1:239,2\n115#1:241,2\n112#1:243\n144#1:244\n144#1:249\n145#1:250\n145#1:255\n149#1:256,3\n159#1:259,3\n169#1:262,3\n201#1:265,2\n202#1:267,2\n203#1:269,2\n204#1:271,2\n144#1:245,4\n145#1:251,4\n*E\n"})
public final class SequentialTrainer
implements Trainer {
    @NotNull
    private final SequentialNetwork network;
    @NotNull
    private final List<Exercise> cases;
    @NotNull
    private final LossFunction lossFunction;
    @NotNull
    private final Optimizer optimizer;
    @NotNull
    private final Statistics statistics;
    private final double acceptableError;
    private boolean done;
    @Nullable
    private EpochStatistics epochStatistics;
    @Nullable
    private BatchStatistics batchStatistics;
    @Nullable
    private SampleStatistics sampleStatistics;
    @Nullable
    private LayerStatistics layerStatistics;
    @Nullable
    private final SinglePassOptimizer singlePassOptimizer;
    @Nullable
    private final MultiPassOptimizer multiPassOptimizer;
    @NotNull
    private final MutableStateFlow<Snapshot> _snapshot;
    @NotNull
    private TensorPool tensorPool;
    private int epoch;
    public static final int $stable = 8;

    public SequentialTrainer(@NotNull SequentialNetwork network, @NotNull List<Exercise> cases, @NotNull LossFunction lossFunction, @NotNull Optimizer optimizer, @NotNull Statistics statistics, double acceptableError) {
        Intrinsics.checkNotNullParameter((Object)network, (String)"network");
        Intrinsics.checkNotNullParameter(cases, (String)"cases");
        Intrinsics.checkNotNullParameter((Object)lossFunction, (String)"lossFunction");
        Intrinsics.checkNotNullParameter((Object)optimizer, (String)"optimizer");
        Intrinsics.checkNotNullParameter((Object)statistics, (String)"statistics");
        this.network = network;
        this.cases = cases;
        this.lossFunction = lossFunction;
        this.optimizer = optimizer;
        this.statistics = statistics;
        this.acceptableError = acceptableError;
        Object object = this.statistics;
        this.epochStatistics = object instanceof EpochStatistics ? (EpochStatistics)object : null;
        object = this.statistics;
        this.batchStatistics = object instanceof BatchStatistics ? (BatchStatistics)object : null;
        object = this.statistics;
        this.sampleStatistics = object instanceof SampleStatistics ? (SampleStatistics)object : null;
        object = this.statistics;
        this.layerStatistics = object instanceof LayerStatistics ? (LayerStatistics)object : null;
        object = this.optimizer;
        this.singlePassOptimizer = object instanceof SinglePassOptimizer ? (SinglePassOptimizer)object : null;
        object = this.optimizer;
        this.multiPassOptimizer = object instanceof MultiPassOptimizer ? (MultiPassOptimizer)object : null;
        if (!(this.singlePassOptimizer != null || this.multiPassOptimizer != null)) {
            boolean bl = false;
            String string2 = "The given optimizer (" + this.optimizer + ") must implement SinglePassOptimizer or MultiPassOptimizer.";
            throw new IllegalArgumentException(string2.toString());
        }
        this._snapshot = StateFlowKt.MutableStateFlow((Object)Snapshot.Companion.getEmpty());
        this.tensorPool = new TensorPool();
    }

    public /* synthetic */ SequentialTrainer(SequentialNetwork sequentialNetwork, List list, LossFunction lossFunction, Optimizer optimizer, Statistics statistics, double d, int n, DefaultConstructorMarker defaultConstructorMarker) {
        if ((n & 0x20) != 0) {
            d = 0.001;
        }
        this(sequentialNetwork, list, lossFunction, optimizer, statistics, d);
    }

    @NotNull
    public final SequentialNetwork getNetwork() {
        return this.network;
    }

    @NotNull
    public final List<Exercise> getCases() {
        return this.cases;
    }

    @NotNull
    public final LossFunction getLossFunction() {
        return this.lossFunction;
    }

    @NotNull
    public final Optimizer getOptimizer() {
        return this.optimizer;
    }

    @NotNull
    public final Statistics getStatistics() {
        return this.statistics;
    }

    public SequentialTrainer(@NotNull SequentialNetwork network, @NotNull List<Exercise> cases, @NotNull LossFunction lossFunction, @NotNull StatisticalOptimizer optimizer) {
        Intrinsics.checkNotNullParameter((Object)network, (String)"network");
        Intrinsics.checkNotNullParameter(cases, (String)"cases");
        Intrinsics.checkNotNullParameter((Object)lossFunction, (String)"lossFunction");
        Intrinsics.checkNotNullParameter((Object)optimizer, (String)"optimizer");
        this(network, cases, lossFunction, optimizer, optimizer, 0.0, 32, null);
    }

    @Override
    @NotNull
    public StateFlow<Snapshot> getSnapshot() {
        return FlowKt.asStateFlow(this._snapshot);
    }

    /*
     * WARNING - void declaration
     */
    @Override
    public int train(int iterations) {
        void $this$mapTo$iv$iv;
        void $this$mapTo$iv$iv2;
        void $this$mapTo$iv$iv3;
        Collection collection;
        void $this$mapTo$iv$iv4;
        if (this.done) {
            return 0;
        }
        ComputationContext computationContext = new ComputationContext();
        long $this$key_u24default$iv = 0L;
        KClass type$iv = Reflection.getOrCreateKotlinClass(List.class);
        boolean $i$f$key = false;
        ComputationKey weightGradientsKey = new ComputationKey($this$key_u24default$iv, type$iv);
        long $this$key_u24default$iv2 = 1L;
        KClass type$iv2 = Reflection.getOrCreateKotlinClass(List.class);
        boolean $i$f$key2 = false;
        ComputationKey biasGradientsKey = new ComputationKey($this$key_u24default$iv2, type$iv2);
        long $this$key_u24default$iv3 = 2L;
        KClass type$iv3 = Reflection.getOrCreateKotlinClass(List.class);
        boolean $i$f$key3 = false;
        ComputationKey activationsKey = new ComputationKey($this$key_u24default$iv3, type$iv3);
        long $this$key_u24default$iv4 = 3L;
        KClass type$iv4 = Reflection.getOrCreateKotlinClass(List.class);
        boolean $i$f$key2232 = false;
        ComputationKey derivativeActivationsKey = new ComputationKey($this$key_u24default$iv4, type$iv4);
        Iterable $this$map$iv = this.network.getLayers();
        boolean $i$f$map = false;
        Iterable $i$f$key2232 = $this$map$iv;
        Iterable destination$iv$iv = new ArrayList(CollectionsKt.collectionSizeOrDefault((Iterable)$this$map$iv, (int)10));
        boolean $i$f$mapTo = false;
        for (Object item$iv$iv : $this$mapTo$iv$iv4) {
            void it;
            DenseLayer denseLayer = (DenseLayer)item$iv$iv;
            collection = destination$iv$iv;
            boolean bl = false;
            collection.add((Tensor)this.tensorPool.getObject((Object)ArraysKt.toList((int[])it.getWeights().getDimensions())));
        }
        List weightGradients = (List)destination$iv$iv;
        Iterable $this$map$iv2 = this.network.getLayers();
        boolean $i$f$map2 = false;
        destination$iv$iv = $this$map$iv2;
        Iterable destination$iv$iv2 = new ArrayList(CollectionsKt.collectionSizeOrDefault((Iterable)$this$map$iv2, (int)10));
        boolean $i$f$mapTo2 = false;
        for (Object item$iv$iv : $this$mapTo$iv$iv3) {
            void it;
            DenseLayer bl = (DenseLayer)item$iv$iv;
            collection = destination$iv$iv2;
            boolean bl2 = false;
            collection.add((Tensor)this.tensorPool.getObject((Object)ArraysKt.toList((int[])it.getBiases().getDimensions())));
        }
        List biasGradients = (List)destination$iv$iv2;
        Iterable $this$map$iv3 = this.network.getLayers();
        boolean $i$f$map3 = false;
        destination$iv$iv2 = $this$map$iv3;
        Iterable destination$iv$iv3 = new ArrayList(CollectionsKt.collectionSizeOrDefault((Iterable)$this$map$iv3, (int)10));
        boolean $i$f$mapTo3 = false;
        for (Object item$iv$iv : $this$mapTo$iv$iv2) {
            void it;
            DenseLayer bl2 = (DenseLayer)item$iv$iv;
            collection = destination$iv$iv3;
            boolean bl = false;
            collection.add((Tensor)this.tensorPool.getObject((Object)ArraysKt.toList((int[])it.getBiases().getDimensions())));
        }
        List activations = (List)destination$iv$iv3;
        Iterable $this$map$iv4 = this.network.getLayers();
        int $i$f$map4 = 0;
        destination$iv$iv3 = $this$map$iv4;
        Collection destination$iv$iv4 = new ArrayList(CollectionsKt.collectionSizeOrDefault((Iterable)$this$map$iv4, (int)10));
        int $i$f$mapTo4 = 0;
        for (Object item$iv$iv : $this$mapTo$iv$iv) {
            void it;
            DenseLayer bl = (DenseLayer)item$iv$iv;
            collection = destination$iv$iv4;
            boolean bl3 = false;
            collection.add((Tensor)this.tensorPool.getObject((Object)ArraysKt.toList((int[])it.getBiases().getDimensions())));
        }
        List derivativeActivations = (List)destination$iv$iv4;
        computationContext.put(weightGradientsKey, weightGradients);
        computationContext.put(biasGradientsKey, biasGradients);
        computationContext.put(activationsKey, activations);
        computationContext.put(derivativeActivationsKey, derivativeActivations);
        Snapshot latestSnapshot = null;
        for ($i$f$map4 = 0; $i$f$map4 < iterations; ++$i$f$map4) {
            boolean bl;
            int index = $i$f$map4;
            boolean bl4 = false;
            $i$f$mapTo4 = this.epoch;
            this.epoch = $i$f$mapTo4 + 1;
            EpochStatistics epochStatistics = this.epochStatistics;
            if (epochStatistics != null) {
                epochStatistics.onEpochStart(this.epoch, computationContext);
                v1 = Unit.INSTANCE;
            } else {
                v1 = null;
            }
            int step = 0;
            double epochLossSum = 0.0;
            do {
                Tensor it;
                ++step;
                List<List<Exercise>> sampleBatches = this.optimizer.batch(this.cases);
                Iterable $this$forEach$iv = weightGradients;
                boolean $i$f$forEach = false;
                for (Object t : $this$forEach$iv) {
                    it = (Tensor)t;
                    boolean bl5 = false;
                    it.zero();
                }
                $this$forEach$iv = biasGradients;
                $i$f$forEach = false;
                for (Object t : $this$forEach$iv) {
                    it = (Tensor)t;
                    boolean bl6 = false;
                    it.zero();
                }
                double batchLossSum = 0.0;
                for (List list : sampleBatches) {
                    Unit unit;
                    boolean $i$f$forEachIndexed;
                    Iterable $this$forEachIndexed$iv;
                    void var43_117;
                    Tensor tensor;
                    Iterator iterator;
                    boolean $i$f$mapEach;
                    Tensor this_$iv;
                    Tensor gradient;
                    int n;
                    BatchStatistics batchStatistics = this.batchStatistics;
                    if (batchStatistics != null) {
                        batchStatistics.onBatchStart(list);
                        v3 = Unit.INSTANCE;
                    } else {
                        v3 = null;
                    }
                    float caseCount = list.size();
                    double sampleLossSum = 0.0;
                    Iterable $this$forEach$iv2 = list;
                    boolean $i$f$forEach2 = false;
                    for (Object element$iv : $this$forEach$iv2) {
                        Unit unit2;
                        Exercise exercise = (Exercise)element$iv;
                        boolean bl7 = false;
                        Iterable $this$forEach$iv3 = activations;
                        boolean $i$f$forEach3 = false;
                        for (Object element$iv2 : $this$forEach$iv3) {
                            Tensor it3 = (Tensor)element$iv2;
                            boolean bl8 = false;
                            it3.zero();
                        }
                        Iterable $this$forEach$iv4 = derivativeActivations;
                        boolean $i$f$forEach4 = false;
                        for (Object element$iv2 : $this$forEach$iv4) {
                            Tensor it4 = (Tensor)element$iv2;
                            n = 0;
                            it4.zero();
                        }
                        SampleStatistics sampleStatistics = this.sampleStatistics;
                        if (sampleStatistics != null) {
                            sampleStatistics.onSampleStart(exercise);
                            v5 = Unit.INSTANCE;
                        } else {
                            v5 = null;
                        }
                        this.network.forward(exercise.getInput(), activations, derivativeActivations);
                        Tensor output = (Tensor)CollectionsKt.last((List)activations);
                        double sampleLoss = this.lossFunction.compute(exercise.getOutput(), output);
                        SampleStatistics sampleStatistics2 = this.sampleStatistics;
                        if (sampleStatistics2 != null) {
                            sampleStatistics2.onCost(sampleLoss);
                            v7 = Unit.INSTANCE;
                        } else {
                            v7 = null;
                        }
                        sampleLossSum += sampleLoss;
                        Tensor errorGradient = this.lossFunction.derivative(exercise.getOutput(), output);
                        this.network.backward(exercise.getInput(), errorGradient, activations, derivativeActivations, weightGradients, biasGradients);
                        SampleStatistics sampleStatistics3 = this.sampleStatistics;
                        if (sampleStatistics3 != null) {
                            sampleStatistics3.onGradient();
                            v9 = Unit.INSTANCE;
                        } else {
                            v9 = null;
                        }
                        SampleStatistics sampleStatistics4 = this.sampleStatistics;
                        if (sampleStatistics4 != null) {
                            sampleStatistics4.onSampleEnd(exercise);
                            unit2 = Unit.INSTANCE;
                            continue;
                        }
                        unit2 = null;
                    }
                    Iterable $this$forEach$iv5 = weightGradients;
                    boolean $i$f$forEach5 = false;
                    for (Object element$iv : $this$forEach$iv5) {
                        gradient = (Tensor)element$iv;
                        boolean bl9 = false;
                        this_$iv = gradient;
                        $i$f$mapEach = false;
                        iterator = this_$iv.getFlatIndices();
                        while (iterator.hasNext()) {
                            void it5;
                            int index$iv = ((Number)iterator.next()).intValue();
                            float it4 = this_$iv.get(index$iv);
                            n = index$iv;
                            tensor = this_$iv;
                            boolean bl10 = false;
                            var43_117 = it5 / caseCount;
                            tensor.set(n, (float)var43_117);
                        }
                    }
                    $this$forEach$iv5 = biasGradients;
                    $i$f$forEach5 = false;
                    for (Object element$iv : $this$forEach$iv5) {
                        gradient = (Tensor)element$iv;
                        boolean bl11 = false;
                        this_$iv = gradient;
                        $i$f$mapEach = false;
                        iterator = this_$iv.getFlatIndices();
                        while (iterator.hasNext()) {
                            void it6;
                            int index$iv = ((Number)iterator.next()).intValue();
                            float it5 = this_$iv.get(index$iv);
                            n = index$iv;
                            tensor = this_$iv;
                            boolean bl12 = false;
                            var43_117 = it6 / caseCount;
                            tensor.set(n, (float)var43_117);
                        }
                    }
                    if (this.singlePassOptimizer != null) {
                        boolean bl13 = false;
                        $this$forEachIndexed$iv = this.network.getLayers();
                        $i$f$forEachIndexed = false;
                        int index$iv = 0;
                        for (Object item$iv : $this$forEachIndexed$iv) {
                            void layer;
                            SinglePassOptimizer it2;
                            int n2;
                            if ((n2 = index$iv++) < 0) {
                                CollectionsKt.throwIndexOverflow();
                            }
                            DenseLayer it6 = (DenseLayer)item$iv;
                            int index2 = n2;
                            boolean bl14 = false;
                            it2.update(this.epoch, (Layer)layer, (Tensor)weightGradients.get(index2), (Tensor)biasGradients.get(index2));
                        }
                        v12 = Unit.INSTANCE;
                    } else {
                        v12 = null;
                    }
                    if (this.multiPassOptimizer != null) {
                        boolean bl15 = false;
                        $this$forEachIndexed$iv = this.network.getLayers();
                        $i$f$forEachIndexed = false;
                        int index$iv2 = 0;
                        for (Object item$iv : $this$forEachIndexed$iv) {
                            MultiPassOptimizer it2;
                            int n3;
                            if ((n3 = index$iv2++) < 0) {
                                CollectionsKt.throwIndexOverflow();
                            }
                            DenseLayer layer = (DenseLayer)item$iv;
                            int index2 = n3;
                            boolean bl16 = false;
                            it2.update(step, this.epoch, (Layer)layer, (Tensor)weightGradients.get(index2), (Tensor)biasGradients.get(index2));
                        }
                        v13 = Unit.INSTANCE;
                    } else {
                        v13 = null;
                    }
                    Iterable $this$forEachIndexed$iv2 = this.network.getLayers();
                    boolean $i$f$forEachIndexed2 = false;
                    int index$iv = 0;
                    for (Object item$iv : $this$forEachIndexed$iv2) {
                        Unit unit3;
                        int n4;
                        if ((n4 = index$iv++) < 0) {
                            CollectionsKt.throwIndexOverflow();
                        }
                        DenseLayer index$iv2 = (DenseLayer)item$iv;
                        int index3 = n4;
                        boolean bl17 = false;
                        SinglePassOptimizer singlePassOptimizer = this.singlePassOptimizer;
                        if (singlePassOptimizer != null) {
                            void layer;
                            singlePassOptimizer.update(this.epoch, (Layer)layer, (Tensor)weightGradients.get(index3), (Tensor)biasGradients.get(index3));
                            unit3 = Unit.INSTANCE;
                            continue;
                        }
                        unit3 = null;
                    }
                    double averageSampleLoss = sampleLossSum / (double)list.size();
                    batchLossSum += averageSampleLoss;
                    BatchStatistics batchStatistics2 = this.batchStatistics;
                    if (batchStatistics2 != null) {
                        batchStatistics2.onBatchEnd(list);
                        unit = Unit.INSTANCE;
                        continue;
                    }
                    unit = null;
                }
                double averageBatchLoss = batchLossSum / (double)sampleBatches.size();
                epochLossSum += averageBatchLoss;
                MultiPassOptimizer multiPassOptimizer = this.multiPassOptimizer;
                if (multiPassOptimizer != null) {
                    if (multiPassOptimizer.passAgain()) {
                        bl = true;
                        continue;
                    }
                    bl = false;
                    continue;
                }
                bl = false;
            } while (bl);
            double averageLoss = epochLossSum / (double)step;
            computationContext.put(TrainingControllerKt.getLossKey(), Float.valueOf((float)averageLoss));
            EpochStatistics epochStatistics2 = this.epochStatistics;
            if (epochStatistics2 != null) {
                epochStatistics2.onEpochEnd(this.epoch, computationContext);
                v21 = Unit.INSTANCE;
            } else {
                v21 = null;
            }
            latestSnapshot = this.statistics.createSnapshot();
            this._snapshot.setValue((Object)latestSnapshot);
            float error = ((Number)computationContext.get(TrainingControllerKt.getLossKey())).floatValue();
            if (!((double)error < this.acceptableError)) continue;
            this.done = true;
            return index + 1;
        }
        Iterable $this$forEach$iv = weightGradients;
        boolean $i$f$forEach = false;
        for (Object element$iv : $this$forEach$iv) {
            Tensor it = (Tensor)element$iv;
            boolean bl = false;
            this.tensorPool.returnObject((Object)it);
        }
        $this$forEach$iv = biasGradients;
        $i$f$forEach = false;
        for (Object element$iv : $this$forEach$iv) {
            Tensor it = (Tensor)element$iv;
            boolean bl = false;
            this.tensorPool.returnObject((Object)it);
        }
        $this$forEach$iv = activations;
        $i$f$forEach = false;
        for (Object element$iv : $this$forEach$iv) {
            Tensor it = (Tensor)element$iv;
            boolean bl = false;
            this.tensorPool.returnObject((Object)it);
        }
        $this$forEach$iv = derivativeActivations;
        $i$f$forEach = false;
        for (Object element$iv : $this$forEach$iv) {
            Tensor it = (Tensor)element$iv;
            boolean bl = false;
            this.tensorPool.returnObject((Object)it);
        }
        return iterations;
    }
}

