/*
 * Decompiled with CFR 0.152.
 */
package ml.dmlc.xgboost4j.java.example;

import java.io.File;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.HashMap;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoost;
import ml.dmlc.xgboost4j.java.XGBoostError;
import ml.dmlc.xgboost4j.java.example.util.DataLoader;

/**
 * End-to-end walk-through of the XGBoost4J Java API using the agaricus demo data.
 *
 * <p>The example:
 * <ol>
 *   <li>trains a booster from the text training file;</li>
 *   <li>saves the model, a text dump of it, and a binary copy of the test matrix;</li>
 *   <li>reloads the model and the test matrix and checks the predictions match;</li>
 *   <li>rebuilds the training matrix from CSR sparse arrays and retrains.</li>
 * </ol>
 *
 * <p>All paths are relative to the working directory.
 * NOTE(review): presumably the example is meant to be run from the examples module directory —
 * confirm.
 */
public class BasicWalkThrough {
    /** Training data; also parsed by {@code DataLoader.loadSVMFile} for the CSR section. */
    private static final String TRAIN_DATA_PATH = "../../demo/data/agaricus.txt.train";
    /** Test data. */
    private static final String TEST_DATA_PATH = "../../demo/data/agaricus.txt.test";
    /** Feature map passed to {@code Booster.getModelDump}. */
    private static final String FEATURE_MAP_PATH = "../../demo/data/featmap.txt";
    /** Directory that receives the saved model, its dump, and the binary test matrix. */
    private static final String MODEL_DIR = "./model";

    /**
     * Compares two prediction matrices element by element.
     *
     * @param fPredicts first prediction matrix
     * @param sPredicts second prediction matrix
     * @return {@code true} if both matrices have the same number of rows and every pair of rows
     *     is equal under {@link Arrays#equals(float[], float[])} semantics; {@code false}
     *     otherwise. Two {@code null} matrices compare equal.
     */
    public static boolean checkPredicts(float[][] fPredicts, float[][] sPredicts) {
        // deepEquals performs the same row-count check and row-wise Arrays.equals(float[], float[])
        // comparison as a manual loop, but handles null arguments instead of throwing.
        return Arrays.deepEquals(fPredicts, sPredicts);
    }

    /**
     * Writes a text dump of a model to a UTF-8 file.
     *
     * <p>Each entry of {@code modelInfos} is written in order, preceded by a
     * {@code booster[i]:} header line. An existing file is truncated.
     *
     * @param modelPath destination file for the dump
     * @param modelInfos dump strings, one per section (e.g. from {@code Booster.getModelDump})
     * @throws IOException if the file cannot be opened or any write to it fails
     */
    public static void saveDumpModel(String modelPath, String[] modelInfos) throws IOException {
        // try-with-resources guarantees the file handle is released even if writing fails.
        try (PrintWriter writer = new PrintWriter(modelPath, "UTF-8")) {
            for (int i = 0; i < modelInfos.length; ++i) {
                writer.print("booster[" + i + "]:\n");
                writer.print(modelInfos[i]);
            }
            // PrintWriter.print never throws; checkError() flushes and reports whether any write
            // failed, so I/O problems surface instead of silently leaving a truncated dump.
            if (writer.checkError()) {
                throw new IOException("Failed to write model dump to " + modelPath);
            }
        }
    }

    /**
     * Runs the walk-through.
     *
     * <p>Prints {@code true}/{@code false} for each prediction comparison.
     *
     * @param args ignored
     * @throws IOException if the model directory cannot be created or the dump cannot be written
     * @throws XGBoostError if an XGBoost operation fails
     */
    public static void main(String[] args) throws IOException, XGBoostError {
        DMatrix trainMat = new DMatrix(TRAIN_DATA_PATH);
        DMatrix testMat = new DMatrix(TEST_DATA_PATH);

        HashMap<String, Object> params = new HashMap<String, Object>();
        params.put("eta", 1.0);
        params.put("max_depth", 2);
        params.put("silent", 1);
        params.put("objective", "binary:logistic");

        // Named matrices handed to train() as watches.
        // NOTE(review): presumably evaluated after each round — confirm against XGBoost.train.
        HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
        watches.put("train", trainMat);
        watches.put("test", testMat);

        int round = 2;
        // NOTE(review): the two trailing nulls presumably mean "no custom objective / evaluation"
        // — confirm against the XGBoost.train signature.
        Booster booster = XGBoost.train(trainMat, params, round, watches, null, null);
        float[][] predicts = booster.predict(testMat);

        // Persist the model, a text dump of it, and a binary copy of the test matrix.
        File modelDir = new File(MODEL_DIR);
        if (!modelDir.isDirectory() && !modelDir.mkdirs()) {
            throw new IOException("Could not create model directory " + modelDir.getAbsolutePath());
        }
        String modelPath = MODEL_DIR + "/xgb.model";
        booster.saveModel(modelPath);
        String[] modelInfos = booster.getModelDump(FEATURE_MAP_PATH, false);
        saveDumpModel(MODEL_DIR + "/dump.raw.txt", modelInfos);
        String testBufferPath = MODEL_DIR + "/dtest.buffer";
        testMat.saveBinary(testBufferPath);

        // Reload the model and the test matrix.
        // Print whether the reloaded pair reproduces the original predictions.
        Booster booster2 = XGBoost.loadModel(modelPath);
        DMatrix testMat2 = new DMatrix(testBufferPath);
        float[][] predicts2 = booster2.predict(testMat2);
        System.out.println(checkPredicts(predicts, predicts2));

        // Rebuild the training matrix from CSR arrays and retrain with the same parameters.
        // Print whether the new booster's predictions match the first booster's.
        System.out.println("start build dmatrix from csr sparse data ...");
        DataLoader.CSRSparseData spData = DataLoader.loadSVMFile(TRAIN_DATA_PATH);
        DMatrix trainMat2 = new DMatrix(
                spData.rowHeaders, spData.colIndex, spData.data, DMatrix.SparseType.CSR);
        trainMat2.setLabel(spData.labels);

        HashMap<String, DMatrix> watches2 = new HashMap<String, DMatrix>();
        watches2.put("train", trainMat2);
        watches2.put("test", testMat2);

        Booster booster3 = XGBoost.train(trainMat2, params, round, watches2, null, null);
        float[][] predicts3 = booster3.predict(testMat2);
        System.out.println(checkPredicts(predicts, predicts3));
    }
}

