/*
 * Decompiled with CFR 0.152.
 */
package ml.shifu.guagua.example.nn;

import java.io.File;
import java.util.Iterator;
import java.util.NoSuchElementException;
import ml.shifu.guagua.util.SizeEstimator;
import org.encog.ml.data.MLData;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLData;
import org.encog.ml.data.basic.BasicMLDataPair;
import org.encog.ml.data.basic.BasicMLDataSet;
import org.encog.ml.data.buffer.BufferedMLDataSet;

public class MemoryDiskMLDataSet
implements MLDataSet {
    private long maxByteSize = Long.MAX_VALUE;
    private long byteSize = 0L;
    private MLDataSet memoryDataSet;
    private MLDataSet diskDataSet;
    private int inputCount;
    private int outputCount;
    private String fileName;
    private long memoryCount = 0L;
    private long diskCount = 0L;

    public MemoryDiskMLDataSet(String fileName, int inputCount, int outputCount) {
        this.memoryDataSet = new BasicMLDataSet();
        this.inputCount = inputCount;
        this.outputCount = outputCount;
        this.fileName = fileName;
    }

    public MemoryDiskMLDataSet(long maxByteSize, String fileName) {
        this.maxByteSize = maxByteSize;
        this.memoryDataSet = new BasicMLDataSet();
        this.fileName = fileName;
    }

    public MemoryDiskMLDataSet(long maxByteSize, String fileName, int inputCount, int outputCount) {
        this.maxByteSize = maxByteSize;
        this.memoryDataSet = new BasicMLDataSet();
        this.inputCount = inputCount;
        this.outputCount = outputCount;
        this.fileName = fileName;
    }

    public final void beginLoad(int inputSize, int idealSize) {
        this.inputCount = inputSize;
        this.outputCount = idealSize;
        if (this.diskDataSet != null) {
            ((BufferedMLDataSet)this.diskDataSet).beginLoad(this.inputCount, this.outputCount);
        }
    }

    public final void endLoad() {
        if (this.diskDataSet != null) {
            ((BufferedMLDataSet)this.diskDataSet).endLoad();
        }
    }

    public Iterator<MLDataPair> iterator() {
        return new Iterator<MLDataPair>(){
            private Iterator<MLDataPair> iter1;
            private Iterator<MLDataPair> iter2;
            private boolean isMemoryHasNext;
            private boolean isDiskHasNext;
            {
                this.iter1 = MemoryDiskMLDataSet.this.memoryDataSet.iterator();
                this.iter2 = MemoryDiskMLDataSet.this.diskDataSet == null ? null : MemoryDiskMLDataSet.this.diskDataSet.iterator();
                this.isMemoryHasNext = false;
                this.isDiskHasNext = false;
            }

            @Override
            public boolean hasNext() {
                boolean hasNext = this.iter1.hasNext();
                if (hasNext) {
                    this.isMemoryHasNext = true;
                    this.isDiskHasNext = false;
                    return hasNext;
                }
                boolean bl = hasNext = this.iter2 == null ? false : this.iter2.hasNext();
                if (hasNext) {
                    this.isMemoryHasNext = false;
                    this.isDiskHasNext = true;
                } else {
                    this.isMemoryHasNext = false;
                    this.isDiskHasNext = false;
                }
                return hasNext;
            }

            @Override
            public MLDataPair next() {
                if (this.isMemoryHasNext) {
                    return this.iter1.next();
                }
                if (this.isDiskHasNext && this.iter2 != null) {
                    return this.iter2.next();
                }
                return null;
            }

            @Override
            public void remove() {
                throw new UnsupportedOperationException();
            }
        };
    }

    public int getIdealSize() {
        return this.outputCount;
    }

    public int getInputSize() {
        return this.inputCount;
    }

    public boolean isSupervised() {
        return this.memoryDataSet.isSupervised();
    }

    public long getRecordCount() {
        long count = this.memoryDataSet.getRecordCount();
        if (this.diskDataSet != null) {
            count += this.diskDataSet.getRecordCount();
        }
        return count;
    }

    public void getRecord(long index, MLDataPair pair) {
        if (index < this.memoryCount) {
            this.memoryDataSet.getRecord(index, pair);
        } else {
            this.diskDataSet.getRecord(index - this.memoryCount, pair);
        }
    }

    public MLDataSet openAdditional() {
        throw new UnsupportedOperationException();
    }

    public void add(MLData data) {
        long currentSize = SizeEstimator.estimate((Object)data);
        if (this.byteSize + currentSize < this.maxByteSize) {
            this.byteSize += currentSize;
            ++this.memoryCount;
            this.memoryDataSet.add(data);
        } else {
            if (this.diskDataSet == null) {
                this.diskDataSet = new BufferedMLDataSet(new File(this.fileName));
                ((BufferedMLDataSet)this.diskDataSet).beginLoad(this.inputCount, this.outputCount);
            }
            this.byteSize += currentSize;
            ++this.diskCount;
            this.diskDataSet.add(data);
        }
    }

    public void add(MLData inputData, MLData idealData) {
        long currentSize = SizeEstimator.estimate((Object)inputData) + SizeEstimator.estimate((Object)idealData);
        if (this.byteSize + currentSize < this.maxByteSize) {
            this.byteSize += currentSize;
            ++this.memoryCount;
            this.memoryDataSet.add(inputData, idealData);
        } else {
            if (this.diskDataSet == null) {
                this.diskDataSet = new BufferedMLDataSet(new File(this.fileName));
                ((BufferedMLDataSet)this.diskDataSet).beginLoad(this.inputCount, this.outputCount);
            }
            this.byteSize += currentSize;
            ++this.diskCount;
            this.diskDataSet.add(inputData, idealData);
        }
    }

    public void add(MLDataPair inputData) {
        long currentSize = SizeEstimator.estimate((Object)inputData);
        if (this.byteSize + currentSize < this.maxByteSize) {
            this.byteSize += currentSize;
            ++this.memoryCount;
            this.memoryDataSet.add(inputData);
        } else {
            if (this.diskDataSet == null) {
                this.diskDataSet = new BufferedMLDataSet(new File(this.fileName));
                ((BufferedMLDataSet)this.diskDataSet).beginLoad(this.inputCount, this.outputCount);
            }
            this.byteSize += currentSize;
            ++this.diskCount;
            this.diskDataSet.add(inputData);
        }
    }

    public void close() {
        this.memoryDataSet.close();
        if (this.diskDataSet != null) {
            this.diskDataSet.close();
        }
    }

    public long getMemoryCount() {
        return this.memoryCount;
    }

    public long getDiskCount() {
        return this.diskCount;
    }

    public static void main(String[] args) {
        MLDataPair next;
        long start;
        double[] input = MemoryDiskMLDataSet.createInput(1.0);
        double[] output = new double[]{1.0};
        BasicMLDataPair pair = new BasicMLDataPair((MLData)new BasicMLData(input), (MLData)new BasicMLData(output));
        MemoryDiskMLDataSet dataSet = new MemoryDiskMLDataSet(400L, "a.txt");
        dataSet.beginLoad(10, 1);
        dataSet.add((MLDataPair)pair);
        BasicMLDataPair pair2 = new BasicMLDataPair((MLData)new BasicMLData(MemoryDiskMLDataSet.createInput(2.0)), (MLData)new BasicMLData(output));
        BasicMLDataPair pair3 = new BasicMLDataPair((MLData)new BasicMLData(MemoryDiskMLDataSet.createInput(3.0)), (MLData)new BasicMLData(output));
        BasicMLDataPair pair4 = new BasicMLDataPair((MLData)new BasicMLData(MemoryDiskMLDataSet.createInput(4.0)), (MLData)new BasicMLData(output));
        BasicMLDataPair pair5 = new BasicMLDataPair((MLData)new BasicMLData(MemoryDiskMLDataSet.createInput(5.0)), (MLData)new BasicMLData(output));
        BasicMLDataPair pair6 = new BasicMLDataPair((MLData)new BasicMLData(MemoryDiskMLDataSet.createInput(6.0)), (MLData)new BasicMLData(output));
        dataSet.add((MLDataPair)pair2);
        dataSet.add((MLDataPair)pair3);
        dataSet.add((MLDataPair)pair4);
        dataSet.add((MLDataPair)pair5);
        dataSet.add((MLDataPair)pair6);
        dataSet.endLoad();
        long recordCount = dataSet.getRecordCount();
        for (long i = 0L; i < recordCount; ++i) {
            long start2 = System.currentTimeMillis();
            BasicMLDataPair p = new BasicMLDataPair((MLData)new BasicMLData(MemoryDiskMLDataSet.createInput(6.0)), (MLData)new BasicMLData(output));
            dataSet.getRecord(i, (MLDataPair)p);
            System.out.println(System.currentTimeMillis() - start2 + " " + p);
        }
        System.out.println();
        Iterator<MLDataPair> iterator = dataSet.iterator();
        while (iterator.hasNext()) {
            start = System.currentTimeMillis();
            next = iterator.next();
            System.out.println(System.currentTimeMillis() - start + " " + next);
        }
        System.out.println();
        iterator = dataSet.iterator();
        while (iterator.hasNext()) {
            start = System.currentTimeMillis();
            next = iterator.next();
            System.out.println(System.currentTimeMillis() - start + " " + next);
        }
        dataSet.close();
        long size = SizeEstimator.estimate((Object)pair);
        System.out.println(size);
    }

    private static double[] createInput(double d) {
        double[] input = new double[10];
        for (int i = 0; i < input.length; ++i) {
            input[i] = d;
        }
        return input;
    }
}

