/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.training;

import ai.djl.Device;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.training.GradientCollector;
import ai.djl.training.Trainer;
import ai.djl.training.dataset.Batch;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.listener.TrainingListener;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Static helpers for running a simple training loop with a {@link Trainer}.
 *
 * <p>This is a non-instantiable utility class.
 */
public final class EasyTrain {
    private EasyTrain() {
    }

    /**
     * Runs {@code numEpoch} epochs of training and, optionally, validation.
     *
     * <p>Each epoch iterates {@code trainingDataset} through {@link Trainer#iterateDataset}, calling
     * {@link #trainBatch} followed by {@code trainer.step()} for every batch. If {@code validateDataset}
     * is non-null, it is then iterated and each batch is passed to {@link #validateBatch}. Finally,
     * listeners are notified via {@code onEpoch}. Every batch is closed after use, including when
     * processing it throws.
     *
     * @param trainer the trainer used for forward/backward passes and parameter updates
     * @param numEpoch number of epochs to run; a value {@code <= 0} runs nothing
     * @param trainingDataset dataset iterated once per epoch for training
     * @param validateDataset dataset iterated once per epoch for validation, or {@code null} to skip
     *     validation
     * @throws IllegalArgumentException if a batch is not on the same engine as the trainer
     */
    public static void fit(Trainer trainer, int numEpoch, Dataset trainingDataset, Dataset validateDataset) {
        for (int epoch = 0; epoch < numEpoch; ++epoch) {
            for (Batch batch : trainer.iterateDataset(trainingDataset)) {
                try {
                    trainBatch(trainer, batch);
                    trainer.step();
                } finally {
                    // Close the batch even if training fails, so an exception does not skip the close.
                    batch.close();
                }
            }
            if (validateDataset != null) {
                for (Batch batch : trainer.iterateDataset(validateDataset)) {
                    try {
                        validateBatch(trainer, batch);
                    } finally {
                        batch.close();
                    }
                }
            }
            trainer.notifyListeners(listener -> listener.onEpoch(trainer));
        }
    }

    /**
     * Runs one training step (forward pass, loss, backward pass) on a batch, without updating
     * parameters.
     *
     * <p>The batch is split across {@code trainer.getDevices()}. For each split, the trainer's loss
     * is evaluated and back-propagated inside a single {@link GradientCollector}. The labels and
     * predictions of every split are collected into a {@link TrainingListener.BatchData}, which is
     * handed to listeners through {@code onTrainingBatch}. Callers are expected to call
     * {@code trainer.step()} afterwards, as {@link #fit} does.
     *
     * @param trainer the trainer to run the batch through
     * @param batch the batch to train on; it is not closed by this method
     * @throws IllegalArgumentException if the batch is not on the same engine as the trainer
     */
    public static void trainBatch(Trainer trainer, Batch batch) {
        requireSameEngine(trainer, batch);
        // NOTE(review): meaning of the boolean flag to split() is not visible here; passed through as-is.
        Batch[] splits = batch.split(trainer.getDevices(), false);
        TrainingListener.BatchData batchData = newBatchData(batch);
        try (GradientCollector collector = trainer.newGradientCollector()) {
            for (Batch split : splits) {
                NDList data = split.getData();
                NDList labels = split.getLabels();
                NDList preds = trainer.forward(data, labels);

                // The start time is taken before loss evaluation, so the "backward" metric covers
                // both evaluating the loss and back-propagating it.
                long time = System.nanoTime();
                NDArray lossValue = trainer.getLoss().evaluate(labels, preds);
                collector.backward(lossValue);
                trainer.addMetric("backward", time);

                time = System.nanoTime();
                recordOutputs(batchData, labels, preds);
                trainer.addMetric("training-metrics", time);
            }
        }
        trainer.notifyListeners(listener -> listener.onTrainingBatch(trainer, batchData));
    }

    /**
     * Runs a forward-only validation pass on a batch.
     *
     * <p>The batch is split across {@code trainer.getDevices()}. Each split is run through
     * {@code trainer.forward}, with no gradient collection. The labels and predictions are collected
     * into a {@link TrainingListener.BatchData}, which is handed to listeners through
     * {@code onValidationBatch}.
     *
     * @param trainer the trainer to run the batch through
     * @param batch the batch to validate on; it is not closed by this method
     * @throws IllegalArgumentException if the batch is not on the same engine as the trainer
     */
    public static void validateBatch(Trainer trainer, Batch batch) {
        requireSameEngine(trainer, batch);
        Batch[] splits = batch.split(trainer.getDevices(), false);
        TrainingListener.BatchData batchData = newBatchData(batch);
        for (Batch split : splits) {
            NDList data = split.getData();
            NDList labels = split.getLabels();
            NDList preds = trainer.forward(data, labels);
            recordOutputs(batchData, labels, preds);
        }
        trainer.notifyListeners(listener -> listener.onValidationBatch(trainer, batchData));
    }

    /** Throws if the batch's manager uses a different engine instance than the trainer's manager. */
    private static void requireSameEngine(Trainer trainer, Batch batch) {
        if (trainer.getManager().getEngine() != batch.getManager().getEngine()) {
            throw new IllegalArgumentException("The data must be on the same engine as the trainer. You may need to change one of your NDManagers.");
        }
    }

    /** Creates an empty per-device labels/predictions holder for {@code batch}. */
    private static TrainingListener.BatchData newBatchData(Batch batch) {
        return new TrainingListener.BatchData(
                batch, new ConcurrentHashMap<Device, NDList>(), new ConcurrentHashMap<Device, NDList>());
    }

    /**
     * Stores one split's labels and predictions in {@code batchData}, keyed by the device of their
     * first array.
     */
    private static void recordOutputs(TrainingListener.BatchData batchData, NDList labels, NDList preds) {
        batchData.getLabels().put(((NDArray) labels.get(0)).getDevice(), labels);
        batchData.getPredictions().put(((NDArray) preds.get(0)).getDevice(), preds);
    }
}

