/*
 * Decompiled with CFR 0.152.
 */
package ml.dmlc.xgboost4j.java.example;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.IEvaluation;
import ml.dmlc.xgboost4j.java.IObjective;
import ml.dmlc.xgboost4j.java.XGBoost;
import ml.dmlc.xgboost4j.java.XGBoostError;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

/**
 * Example that trains an XGBoost booster with a user-supplied objective ({@link LogRegObj}) and a
 * user-supplied evaluation metric ({@link EvalError}) on the agaricus demo data.
 */
public class CustomObjective {
    /**
     * Loads the agaricus train/test matrices, trains a booster with {@code round = 2} using the
     * custom objective and metric, and disposes the matrices and booster afterwards.
     *
     * <p>NOTE(review): the data paths are relative to the working directory; presumably this is
     * launched from the example module's directory — confirm.
     *
     * @param args unused
     * @throws XGBoostError if loading the data or training fails
     */
    public static void main(String[] args) throws XGBoostError {
        DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
        try {
            DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
            try {
                HashMap<String, Number> params = new HashMap<String, Number>();
                params.put("eta", 1.0);
                params.put("max_depth", 2);
                params.put("silent", 1);
                int round = 2;

                // Named matrices handed to train(); presumably evaluated with the custom
                // metric during training.
                HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
                watches.put("train", trainMat);
                watches.put("test", testMat);

                IObjective obj = new LogRegObj();
                IEvaluation eval = new EvalError();

                System.out.println("begin to train the booster model");
                Booster booster = XGBoost.train(trainMat, params, round, watches, obj, eval);
                // The trained model is not used further in this example, so release it now.
                booster.dispose();
            } finally {
                testMat.dispose();
            }
        } finally {
            trainMat.dispose();
        }
    }

    /**
     * Custom evaluation metric: the fraction of rows whose prediction falls on the wrong side
     * of 0 for their 0/1 label.
     */
    public static class EvalError
    implements IEvaluation {
        private static final Log logger = LogFactory.getLog(EvalError.class);
        String evalMetric = "custom_error";

        /**
         * Returns the name under which this metric is reported.
         *
         * @return the metric name ({@code "custom_error"} by default)
         */
        public String getMetric() {
            return this.evalMetric;
        }

        /**
         * Computes the classification error rate.
         *
         * <p>A row counts as an error when its label is {@code 0} and its prediction is
         * {@code > 0}, or its label is {@code 1} and its prediction is {@code <= 0}. Rows with any
         * other label value never count as errors. NOTE(review): the threshold of 0 assumes
         * {@code predicts} holds raw margin scores rather than probabilities — confirm against
         * the IEvaluation contract.
         *
         * @param predicts one row per instance; only column 0 is read
         * @param dmat the matrix whose labels are compared against {@code predicts}
         * @return errors divided by the number of rows; {@code 0} for an empty prediction set;
         *     {@code -1} if the labels could not be read
         */
        public float eval(float[][] predicts, DMatrix dmat) {
            float[] labels;
            try {
                labels = dmat.getLabel();
            }
            catch (XGBoostError ex) {
                logger.error("failed to read labels while computing " + this.evalMetric, ex);
                return -1.0f;
            }
            int nrow = predicts.length;
            if (nrow == 0) {
                // Avoid 0/0 = NaN: an empty set has no misclassified rows.
                return 0.0f;
            }
            // Count with an int: a float accumulator stops being exact above 2^24 rows.
            int errors = 0;
            for (int i = 0; i < nrow; ++i) {
                float label = labels[i];
                float margin = predicts[i][0];
                if ((label == 0.0f && margin > 0.0f) || (label == 1.0f && margin <= 0.0f)) {
                    ++errors;
                }
            }
            // Divide by the number of rows actually scored, so numerator and denominator agree.
            return (float) errors / (float) nrow;
        }
    }

    /**
     * Custom objective implementing logistic loss: predictions are passed through a sigmoid, and
     * the gradient/hessian are computed from the resulting probabilities.
     */
    public static class LogRegObj
    implements IObjective {

        /**
         * Logistic sigmoid {@code 1 / (1 + e^-input)}, computed in double precision.
         *
         * @param input raw score
         * @return value in [0, 1]
         */
        public float sigmoid(float input) {
            return (float)(1.0 / (1.0 + Math.exp(-input)));
        }

        /**
         * Applies {@link #sigmoid(float)} to column 0 of every row.
         *
         * @param predicts one row per instance; only column 0 is read
         * @return a new {@code nrow x 1} array of sigmoid-transformed values
         */
        public float[][] transform(float[][] predicts) {
            int nrow = predicts.length;
            float[][] transPredicts = new float[nrow][1];
            for (int i = 0; i < nrow; ++i) {
                transPredicts[i][0] = this.sigmoid(predicts[i][0]);
            }
            return transPredicts;
        }

        /**
         * Computes the first- and second-order derivatives of the logistic loss.
         *
         * <p>With {@code p = sigmoid(predicts[i][0])} and label {@code y}, the gradient is
         * {@code p - y} and the hessian is {@code p * (1 - p)}.
         *
         * @param predicts one row per instance; only column 0 is read
         * @param dtrain the training matrix supplying the labels
         * @return a two-element list: the gradient array, then the hessian array
         * @throws IllegalStateException if the labels cannot be read from {@code dtrain}
         *     (previously this returned {@code null}, which only surfaced later as an NPE)
         */
        public List<float[]> getGradient(float[][] predicts, DMatrix dtrain) {
            float[] labels;
            try {
                labels = dtrain.getLabel();
            }
            catch (XGBoostError ex) {
                throw new IllegalStateException("failed to read labels from the training matrix", ex);
            }
            int nrow = predicts.length;
            float[] grad = new float[nrow];
            float[] hess = new float[nrow];
            float[][] transPredicts = this.transform(predicts);
            for (int i = 0; i < nrow; ++i) {
                float predict = transPredicts[i][0];
                grad[i] = predict - labels[i];
                hess[i] = predict * (1.0f - predict);
            }
            List<float[]> gradients = new ArrayList<float[]>(2);
            gradients.add(grad);
            gradients.add(hess);
            return gradients;
        }
    }
}

