package org.deeplearning4j.datasets.fetchers;

import au.com.bytecode.opencsv.CSV;
import au.com.bytecode.opencsv.CSVReadProc;
import java.io.BufferedInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import org.deeplearning4j.berkeley.Pair;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.FeatureUtil;

/**
 * Data fetcher that reads a labelled data set from comma-separated text.
 *
 * <p>Every row contributes one example. One column (chosen by {@code labelColumn}) holds the
 * class label; every other column is parsed as a {@code double} feature. The whole input is read
 * eagerly in the constructor and merged into a single {@link DataSet}. {@link #fetch(int)} then
 * serves consecutive slices of it.
 *
 * <p>The width of the first accepted row fixes the expected number of columns. Later rows with a
 * different width, and rows with no fields at all, are silently skipped.
 */
public class CSVDataFetcher extends BaseDataFetcher {
    private CSV csv;
    private InputStream is;
    private int labelColumn;
    private DataSet all;

    /**
     * Reads CSV data from a stream, honouring {@code "} as the quote character and skipping no
     * leading lines.
     *
     * <p>NOTE(review): this is the only constructor that enables quote handling. The other stream
     * constructor and both {@link File} constructors call {@code noQuote()}. Kept as-is for
     * backward compatibility.
     *
     * @param inputStream source of the CSV text
     * @param labelColumn zero-based index of the column holding the class label
     */
    public CSVDataFetcher(InputStream inputStream, int labelColumn) {
        this.is = inputStream;
        this.labelColumn = labelColumn;
        this.csv = CSV.skipLines(0).separator(',').quote('\"').create();
        init();
    }

    /**
     * Reads CSV data from a file without quote handling and without skipping any lines.
     *
     * @param file CSV file to read
     * @param labelColumn zero-based index of the column holding the class label
     * @throws IOException if the file cannot be opened
     */
    public CSVDataFetcher(File file, int labelColumn) throws IOException {
        this(new BufferedInputStream(new FileInputStream(file)), labelColumn, 0);
    }

    /**
     * Reads CSV data from a stream without quote handling, skipping a number of leading lines
     * (e.g. a header).
     *
     * @param inputStream source of the CSV text
     * @param labelColumn zero-based index of the column holding the class label
     * @param skipLines number of leading lines to ignore
     */
    public CSVDataFetcher(InputStream inputStream, int labelColumn, int skipLines) {
        this.is = inputStream;
        this.labelColumn = labelColumn;
        this.csv = CSV.skipLines(skipLines).separator(',').noQuote().create();
        init();
    }

    /**
     * Reads CSV data from a file without quote handling, skipping a number of leading lines.
     *
     * @param file CSV file to read
     * @param labelColumn zero-based index of the column holding the class label
     * @param skipLines number of leading lines to ignore
     * @throws IOException if the file cannot be opened
     */
    public CSVDataFetcher(File file, int labelColumn, int skipLines) throws IOException {
        this(new BufferedInputStream(new FileInputStream(file)), labelColumn, skipLines);
    }

    /**
     * Parses the whole input and builds {@link #all}. Also sets {@code inputColumns},
     * {@code numOutcomes} and {@code totalExamples}, which are inherited from the base fetcher.
     *
     * <p>NOTE(review): the stream is handed to {@code CSV.read}. Whether it is closed there
     * depends on opencsv; confirm, since the File constructors open the stream themselves.
     */
    private void init() {
        final List<String> rowLabels = new ArrayList<>();
        final List<INDArray> rowFeatures = new ArrayList<>();
        final Set<String> distinctLabels = new HashSet<>();
        // Number of feature columns (row width minus the label); -1 until the first row is seen.
        final AtomicInteger featureWidth = new AtomicInteger(-1);

        this.csv.read(this.is, new CSVReadProc() {
            @Override
            public void procRow(int rowIndex, String... values) {
                if (values.length < 1) {
                    return;
                }
                if (featureWidth.get() < 0) {
                    // First accepted row fixes the expected width for all later rows.
                    featureWidth.set(values.length - 1);
                    CSVDataFetcher.this.inputColumns = values.length - 1;
                } else if (values.length - 1 != featureWidth.get()) {
                    // Skip ragged rows rather than producing feature vectors of differing length.
                    return;
                }
                Pair<INDArray, String> row = CSVDataFetcher.this.processRow(values);
                rowLabels.add(row.getSecond());
                distinctLabels.add(row.getSecond());
                rowFeatures.add(row.getFirst());
            }
        });

        // Label indices follow the HashSet's iteration order, exactly as before. A map replaces
        // the per-row List.indexOf scan.
        List<String> labelOrder = new ArrayList<>(distinctLabels);
        Map<String, Integer> labelIndex = new HashMap<>();
        for (int k = 0; k < labelOrder.size(); k++) {
            labelIndex.put(labelOrder.get(k), k);
        }

        int numLabels = distinctLabels.size();
        List<DataSet> examples = new ArrayList<>(rowLabels.size());
        for (int r = 0; r < rowLabels.size(); r++) {
            INDArray outcome = FeatureUtil.toOutcomeVector(labelIndex.get(rowLabels.get(r)), numLabels);
            examples.add(new DataSet(rowFeatures.get(r), outcome));
        }
        this.numOutcomes = numLabels;
        this.totalExamples = examples.size();
        this.all = DataSet.merge(examples);
    }

    /**
     * Splits one CSV row into its label and a 1 x (columns - 1) feature row vector.
     *
     * <p>Any double-quote characters are removed from the label. This matters for the
     * {@code noQuote()} constructors, where quotes survive parsing. Every non-label column is
     * parsed with {@link Double#parseDouble(String)}.
     *
     * @param strArr the row's fields; must contain {@code labelColumn}
     * @return pair of (feature row vector, label)
     * @throws NumberFormatException if a feature column is not a number
     */
    public Pair<INDArray, String> processRow(String[] strArr) {
        // Literal replace: the former regex ".\"." deleted a quote together with its neighbours
        // and never stripped leading/trailing quotes.
        String label = strArr[this.labelColumn].replace("\"", "");
        double[] features = new double[strArr.length - 1];
        int next = 0;
        for (int col = 0; col < strArr.length; col++) {
            if (col != this.labelColumn) {
                features[next++] = Double.parseDouble(strArr[col]);
            }
        }
        return new Pair<>(Nd4j.create(features).reshape(1, features.length), label);
    }

    /**
     * Loads the next batch of at most {@code numExamples} examples, starting at {@code cursor},
     * via {@code initializeCurrFromList}. The cursor is then advanced by the requested amount,
     * even if fewer examples remained.
     *
     * @param numExamples requested batch size
     */
    @Override
    public void fetch(int numExamples) {
        int total = this.all.numExamples();
        // Overflow-safe clamp of cursor + numExamples to the data set size.
        int end = this.cursor + Math.min(numExamples, total - this.cursor);
        initializeCurrFromList(this.all.asList().subList(this.cursor, end));
        this.cursor += numExamples;
    }
}
