/*
 * Decompiled with CFR 0.152.
 */
package org.encog.neural.networks.training.propagation.sgd;

import java.util.HashSet;
import java.util.Iterator;
import java.util.Set;
import org.encog.EncogError;
import org.encog.mathutil.randomize.generate.GenerateRandom;
import org.encog.mathutil.randomize.generate.MersenneTwisterGenerateRandom;
import org.encog.ml.data.MLData;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;

/**
 * A read-only {@link MLDataSet} view that exposes one mini-batch of an underlying data set at a
 * time.
 *
 * <p>In sequential mode (the default), the batch is {@link #getBatchSize()} consecutive records.
 * It starts at {@link #getCurrentIndex()} and wraps around to the start of the underlying data
 * set. {@link #advance()} moves that window forward by one batch.
 *
 * <p>In random mode ({@link #setRandomBatches(boolean)}), the batch is a set of distinct record
 * indices drawn with the supplied {@link GenerateRandom}. {@link #advance()} draws a new set.
 *
 * <p>The {@code add} methods are not supported and throw {@link EncogError}.
 */
public class BatchDataSet
implements MLDataSet {
    /** Batch size requested by the constructor; it is capped to the underlying data set size. */
    private static final int DEFAULT_BATCH_SIZE = 500;

    /** The wrapped data set; this class only reads from it. */
    private final MLDataSet dataset;
    /** Start of the sequential window; not used for indexing in random mode. */
    private int currentIndex;
    /** Number of records in each batch. */
    private int batchSize;
    /** Source of randomness for random batches and for seeding additional views. */
    private final GenerateRandom random;
    /** Whether batches are drawn at random instead of as a sliding window. */
    private boolean randomBatches;
    /** Underlying data set indices of the current random batch; length equals batchSize. */
    private int[] randomSample;

    /**
     * Creates a sequential batch view over {@code theDataset}.
     *
     * <p>The batch size is 500, or the size of {@code theDataset} if that is smaller.
     *
     * @param theDataset the underlying data set to draw batches from.
     * @param theRandom generator used for random batches and for seeding {@link #openAdditional()}.
     */
    public BatchDataSet(MLDataSet theDataset, GenerateRandom theRandom) {
        this.dataset = theDataset;
        this.random = theRandom;
        this.setBatchSize(DEFAULT_BATCH_SIZE);
    }

    /**
     * Sets the number of records per batch.
     *
     * <p>The value is capped at the current size of the underlying data set. In random mode, a
     * fresh random sample is drawn immediately.
     *
     * @param theSize requested batch size; must not be negative.
     * @throws EncogError if {@code theSize} is negative.
     */
    public void setBatchSize(int theSize) {
        // Validate before mutating so a bad argument leaves the view unchanged.
        if (theSize < 0) {
            throw new EncogError("Batch size must not be negative: " + theSize);
        }
        this.batchSize = Math.min(theSize, this.dataset.size());
        this.randomSample = new int[this.batchSize];
        if (this.randomBatches) {
            this.generateRandomSample();
        }
    }

    /**
     * @return the number of records in each batch.
     */
    public int getBatchSize() {
        return this.batchSize;
    }

    /**
     * @return an iterator over the records of the current batch.
     */
    @Override
    public Iterator<MLDataPair> iterator() {
        return new BatchedMLIterator();
    }

    @Override
    public int getIdealSize() {
        return this.dataset.getIdealSize();
    }

    @Override
    public int getInputSize() {
        return this.dataset.getInputSize();
    }

    @Override
    public boolean isSupervised() {
        return this.dataset.isSupervised();
    }

    /**
     * @return the number of records in the current batch (the batch size).
     */
    @Override
    public long getRecordCount() {
        return this.batchSize;
    }

    /**
     * Copies the record at position {@code index} of the current batch into {@code pair}.
     *
     * <p>This uses the same index mapping as {@link #get(int)}, so both accessors agree in
     * sequential and random mode.
     *
     * @param index position within the current batch.
     * @param pair destination for the record.
     */
    @Override
    public void getRecord(long index, MLDataPair pair) {
        this.dataset.getRecord(this.resolveIndex(index), pair);
    }

    /**
     * Opens a new, independent batch view over the same underlying data set.
     *
     * <p>The new view copies this view's batch size and random-batch setting. Its generator is a
     * {@link MersenneTwisterGenerateRandom} seeded from this view's generator.
     *
     * @return the new batch view.
     */
    @Override
    public MLDataSet openAdditional() {
        BatchDataSet result = new BatchDataSet(this.dataset, new MersenneTwisterGenerateRandom(this.random.nextLong()));
        // Size first, then mode, so a random sample is drawn only once, at the final size.
        result.setBatchSize(this.getBatchSize());
        result.setRandomBatches(this.randomBatches);
        return result;
    }

    @Override
    public void add(MLData data1) {
        throw new EncogError("Unsupported.");
    }

    @Override
    public void add(MLData inputData, MLData idealData) {
        throw new EncogError("Unsupported.");
    }

    @Override
    public void add(MLDataPair inputData) {
        throw new EncogError("Unsupported.");
    }

    /**
     * Does nothing; this view does not close the underlying data set.
     */
    @Override
    public void close() {
    }

    /**
     * @return the number of records in the current batch (the batch size).
     */
    @Override
    public int size() {
        return this.batchSize;
    }

    /**
     * Returns the record at position {@code index} of the current batch.
     *
     * @param index position within the current batch, in {@code [0, getBatchSize())}.
     * @return the corresponding record of the underlying data set.
     */
    @Override
    public MLDataPair get(int index) {
        return this.dataset.get(this.resolveIndex(index));
    }

    /**
     * Moves to the next batch.
     *
     * <p>In random mode, a new random sample is drawn. In sequential mode, the window start moves
     * forward by one batch and wraps around the end of the underlying data set.
     */
    public void advance() {
        if (this.randomBatches) {
            this.generateRandomSample();
            return;
        }
        int population = this.dataset.size();
        // An empty data set has no window to move; skip it to avoid dividing by zero.
        if (population > 0) {
            // Use long arithmetic so the sum cannot overflow before the modulo.
            this.currentIndex = (int) (((long) this.currentIndex + this.batchSize) % population);
        }
    }

    /**
     * @return the start of the sequential window in the underlying data set.
     */
    public int getCurrentIndex() {
        return this.currentIndex;
    }

    /**
     * Sets the start of the sequential window.
     *
     * <p>The value is not validated. It only affects indexing in sequential mode.
     *
     * @param currentIndex new window start.
     */
    public void setCurrentIndex(int currentIndex) {
        this.currentIndex = currentIndex;
    }

    /**
     * @return {@code true} if batches are drawn at random rather than as a sliding window.
     */
    public boolean isRandomBatches() {
        return this.randomBatches;
    }

    /**
     * Switches between random and sequential batches.
     *
     * <p>Switching random mode on draws a sample right away. Without that, the first batch would
     * be read from an unfilled sample array.
     *
     * @param randomBatches {@code true} for random batches, {@code false} for sequential ones.
     */
    public void setRandomBatches(boolean randomBatches) {
        if (randomBatches && !this.randomBatches) {
            this.generateRandomSample();
        }
        this.randomBatches = randomBatches;
    }

    /**
     * Maps a position in the current batch to an index in the underlying data set.
     *
     * <p>Random mode uses the position directly as an offset into the sample array. Sequential
     * mode offsets it by the window start and wraps around the underlying data set.
     */
    private int resolveIndex(long index) {
        if (this.randomBatches) {
            return this.randomSample[(int) index];
        }
        return (int) ((index + this.currentIndex) % this.dataset.size());
    }

    /**
     * Fills {@code randomSample} with {@code batchSize} distinct indices into the underlying data
     * set, using rejection sampling.
     *
     * @throws EncogError if the batch size exceeds the underlying data set size. That can happen
     *     if the data set shrank after {@link #setBatchSize(int)}; no distinct sample would exist
     *     and sampling would never finish.
     */
    private void generateRandomSample() {
        final int population = this.dataset.size();
        if (this.batchSize > population) {
            throw new EncogError("Batch size " + this.batchSize + " exceeds data set size " + population + ".");
        }
        // Set.add returns false for a duplicate, which both detects the clash and triggers a redraw.
        final Set<Integer> chosen = new HashSet<Integer>(Math.max(16, this.batchSize * 2));
        for (int slot = 0; slot < this.batchSize; ++slot) {
            int candidate;
            do {
                candidate = this.random.nextInt(0, population);
            } while (!chosen.add(candidate));
            this.randomSample[slot] = candidate;
        }
    }

    /**
     * Iterates over the records of the enclosing view's current batch, in batch order.
     */
    public class BatchedMLIterator
    implements Iterator<MLDataPair> {
        /** Position of the next record within the batch. */
        private int currentIndex = 0;

        @Override
        public final boolean hasNext() {
            return this.currentIndex < BatchDataSet.this.getBatchSize();
        }

        /**
         * @return the next record of the batch, or {@code null} once the batch is exhausted. This
         *     method does not throw {@code NoSuchElementException}.
         */
        @Override
        public final MLDataPair next() {
            if (!this.hasNext()) {
                return null;
            }
            return BatchDataSet.this.get(this.currentIndex++);
        }

        /**
         * @throws EncogError always; removal is not supported.
         */
        @Override
        public final void remove() {
            throw new EncogError("Called remove, unsupported operation.");
        }
    }
}

