/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.training;

import ai.djl.Device;
import ai.djl.Model;
import ai.djl.metric.Metrics;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.training.GradientCollector;
import ai.djl.training.dataset.Batch;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.evaluator.Evaluator;
import ai.djl.training.loss.Loss;
import java.util.List;

/**
 * Contract for an object that trains a {@link Model} (exposed via {@link #getModel()}).
 *
 * <p>Implementations are {@link AutoCloseable}. {@link #close()} is redeclared here without a
 * {@code throws} clause, so callers using try-with-resources need not handle a checked exception.
 */
public interface Trainer extends AutoCloseable {

    /**
     * Initializes this trainer from the given shapes.
     *
     * <p>NOTE(review): presumably these are the model's input shapes; confirm against
     * implementations.
     *
     * @param shapes the shapes to initialize with
     */
    void initialize(Shape... shapes);

    /**
     * Returns the batches of {@code dataset}.
     *
     * <p>The default implementation delegates to {@code dataset.getData(...)}, passing this
     * trainer's {@link #getManager() manager}.
     *
     * @param dataset the dataset to iterate
     * @return an {@link Iterable} over the dataset's {@link Batch}es
     */
    default Iterable<Batch> iterateDataset(Dataset dataset) {
        NDManager manager = getManager();
        return dataset.getData(manager);
    }

    /**
     * Creates a new {@link GradientCollector}.
     *
     * @return a fresh gradient collector
     */
    GradientCollector newGradientCollector();

    /**
     * Processes one {@link Batch} for training.
     *
     * <p>NOTE(review): which steps this performs (forward/backward, whether {@link #step()} is
     * invoked) is implementation-defined; confirm.
     *
     * @param batch the batch to train on
     */
    void trainBatch(Batch batch);

    /**
     * Runs a forward pass over {@code input}.
     *
     * <p>NOTE(review): presumably through the trained model; confirm.
     *
     * @param input the input arrays
     * @return the output arrays
     */
    NDList forward(NDList input);

    /**
     * Processes one {@link Batch} for validation.
     *
     * @param batch the batch to validate on
     */
    void validateBatch(Batch batch);

    /**
     * Advances training by one step.
     *
     * <p>NOTE(review): presumably applies a parameter update; confirm against implementations.
     */
    void step();

    /**
     * Returns the {@link Metrics} held by this trainer.
     *
     * @return the current metrics
     */
    Metrics getMetrics();

    /**
     * Replaces the {@link Metrics} held by this trainer.
     *
     * @param metrics the metrics to use
     */
    void setMetrics(Metrics metrics);

    /**
     * Returns the devices this trainer uses.
     *
     * @return the list of devices
     */
    List<Device> getDevices();

    /** Signals the end of an epoch. */
    void endEpoch();

    /**
     * Returns the {@link Loss} used by this trainer.
     *
     * @return the loss
     */
    Loss getLoss();

    /**
     * Returns the {@link Model} associated with this trainer.
     *
     * @return the model
     */
    Model getModel();

    /**
     * Returns the evaluators registered with this trainer.
     *
     * @return the list of evaluators
     */
    List<Evaluator> getEvaluators();

    /**
     * Returns the evaluator of the requested type.
     *
     * <p>NOTE(review): the behavior when no evaluator matches is implementation-defined; confirm.
     *
     * @param evaluatorType the class of the evaluator to look up
     * @param <T> the evaluator type
     * @return the matching evaluator
     */
    <T extends Evaluator> T getEvaluator(Class<T> evaluatorType);

    /**
     * Returns the {@link NDManager} of this trainer.
     *
     * <p>{@link #iterateDataset(Dataset)} passes this manager to the dataset by default.
     *
     * @return the manager
     */
    NDManager getManager();

    /** Closes this trainer; unlike {@link AutoCloseable#close()}, throws no checked exception. */
    @Override
    void close();
}

