/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.training.dataset;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.dataset.Record;
import ai.djl.util.Progress;
import java.io.IOException;
import java.util.stream.Stream;

/**
 * A {@link RandomAccessDataset} backed by in-memory {@link NDArray}s.
 *
 * <p>Record {@code i} consists of {@code array.get(i)} for every data array and, when labels
 * are present, {@code array.get(i)} for every label array. All data and label arrays must
 * therefore have the same size along axis 0.
 */
public class ArrayDataset
extends RandomAccessDataset {
    /** The data arrays; each record takes one slice along axis 0 from every array. */
    protected NDArray[] data;
    /** The optional label arrays, or {@code null} when no labels were supplied. */
    protected NDArray[] labels;

    /**
     * Creates a new dataset from a builder.
     *
     * <p>When {@code builder} is an {@link ArrayDataset.Builder}, its data and label arrays are
     * copied (the arrays of references, not the NDArrays themselves) and validated. For any
     * other builder type only the base class is initialized; a subclass is then responsible
     * for populating {@link #data} and {@link #labels}.
     *
     * @param builder the builder holding the dataset configuration
     * @throws IllegalArgumentException if no data array was supplied, or if the data and label
     *     arrays do not all have the same size along axis 0
     * @throws NullPointerException if any data or label element is {@code null}
     */
    public ArrayDataset(RandomAccessDataset.BaseBuilder<?> builder) {
        super(builder);
        if (builder instanceof Builder) {
            Builder arrayBuilder = (Builder) builder;
            // Validate here as well as in build(): this constructor is public and can be
            // invoked directly with an incompletely configured builder.
            if (arrayBuilder.data == null || arrayBuilder.data.length == 0) {
                throw new IllegalArgumentException("Please pass in at least one data");
            }
            // Copy the reference arrays so that later mutation of the caller's (or a reused
            // builder's) arrays can neither alter this dataset nor bypass the checks below.
            this.data = arrayBuilder.data.clone();
            this.labels = arrayBuilder.labels == null ? null : arrayBuilder.labels.clone();

            requireNoNullElements(this.data, "data");
            long size = this.data[0].size(0);
            if (!allHaveSize(this.data, size)) {
                throw new IllegalArgumentException("All the NDArray must have the same length!");
            }
            if (this.labels != null) {
                requireNoNullElements(this.labels, "labels");
                if (!allHaveSize(this.labels, size)) {
                    throw new IllegalArgumentException(
                            "All the NDArray must have the same length!");
                }
            }
        }
    }

    /** Throws a descriptive NPE if any element of {@code arrays} is {@code null}. */
    private static void requireNoNullElements(NDArray[] arrays, String what) {
        for (int i = 0; i < arrays.length; ++i) {
            if (arrays[i] == null) {
                throw new NullPointerException(what + "[" + i + "] must not be null");
            }
        }
    }

    /** Returns whether every array in {@code arrays} has exactly {@code size} entries on axis 0. */
    private static boolean allHaveSize(NDArray[] arrays, long size) {
        return Stream.of(arrays).allMatch(array -> array.size(0) == size);
    }

    /**
     * Returns the number of records, i.e. the size along axis 0 of the first data array.
     *
     * @return the number of available records
     */
    @Override
    protected long availableSize() {
        return this.data[0].size(0);
    }

    /**
     * Builds the record at {@code index} by slicing every data and label array along axis 0.
     *
     * <p>The slices are attached to {@code manager} before being returned.
     *
     * @param manager the manager the returned arrays are attached to
     * @param index the record index
     * @return a record whose data holds one slice per data array and whose labels hold one
     *     slice per label array (empty when no labels were supplied)
     */
    @Override
    public Record get(NDManager manager, long index) {
        NDList datum = new NDList();
        NDList label = new NDList();
        for (NDArray array : this.data) {
            datum.add(array.get(index));
        }
        if (this.labels != null) {
            for (NDArray array : this.labels) {
                label.add(array.get(index));
            }
        }
        datum.attach(manager);
        label.attach(manager);
        return new Record(datum, label);
    }

    /**
     * Does nothing: all data is supplied in memory through the builder.
     *
     * @param progress unused
     * @throws IOException never thrown by this implementation
     */
    @Override
    public void prepare(Progress progress) throws IOException {
        // Intentionally empty: there is nothing to download or load.
    }

    /** A builder for {@link ArrayDataset}. */
    public static final class Builder
    extends RandomAccessDataset.BaseBuilder<Builder> {
        private NDArray[] data;
        private NDArray[] labels;

        @Override
        protected Builder self() {
            return this;
        }

        /**
         * Sets the data arrays; at least one is required.
         *
         * @param data the data arrays, all with the same size along axis 0
         * @return this builder
         */
        public Builder setData(NDArray ... data) {
            this.data = data;
            return this.self();
        }

        /**
         * Sets the optional label arrays.
         *
         * @param labels the label arrays, each with the same size along axis 0 as the data
         * @return this builder
         */
        public Builder optLabels(NDArray ... labels) {
            this.labels = labels;
            return this.self();
        }

        /**
         * Builds the dataset.
         *
         * @return a new {@link ArrayDataset}
         * @throws IllegalArgumentException if no data was set, or the arrays' lengths differ
         */
        public ArrayDataset build() {
            // The constructor performs all validation, with the same messages.
            return new ArrayDataset(this);
        }
    }
}

