/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.xgboost.testing;

import com.google.common.base.Equivalence;
import java.io.InputStream;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Predicate;
import org.dmg.pmml.Header;
import org.dmg.pmml.PMML;
import org.dmg.pmml.PMMLObject;
import org.dmg.pmml.Timestamp;
import org.jpmml.converter.testing.ModelEncoderBatch;
import org.jpmml.converter.testing.OptionsUtil;
import org.jpmml.evaluator.ResultField;
import org.jpmml.model.ReflectionUtil;
import org.jpmml.xgboost.FeatureMap;
import org.jpmml.xgboost.Learner;
import org.jpmml.xgboost.XGBoostUtil;
import org.jpmml.xgboost.testing.XGBoostEncoderBatchTest;

public abstract class XGBoostEncoderBatch
extends ModelEncoderBatch {
    private String format = System.getProperty(XGBoostEncoderBatch.class.getName() + ".format", "model,json");

    public XGBoostEncoderBatch(String algorithm, String dataset, Predicate<ResultField> columnFilter, Equivalence<Object> equivalence) {
        super(algorithm, dataset, columnFilter, equivalence);
    }

    public abstract XGBoostEncoderBatchTest getArchiveBatchTest();

    public List<Map<String, Object>> getOptionsMatrix() {
        String dataset = this.getDataset();
        Integer ntreeLimit = null;
        int index = dataset.indexOf(64);
        if (index > -1) {
            ntreeLimit = new Integer(dataset.substring(index + 1));
        }
        LinkedHashMap<String, Object> options = new LinkedHashMap<String, Object>();
        options.put("compact", new Boolean[]{false, true});
        options.put("prune", true);
        options.put("nan_as_missing", true);
        options.put("ntree_limit", ntreeLimit);
        return OptionsUtil.generateOptionsMatrix(options);
    }

    public String getLearnerPath(String format) {
        return "/xgboost/" + this.getAlgorithm() + XGBoostEncoderBatch.truncate((String)this.getDataset()) + "." + format;
    }

    public String getFeatureMapPath() {
        return "/csv/" + XGBoostEncoderBatch.truncate((String)this.getDataset()) + ".fmap";
    }

    public PMML getPMML() throws Exception {
        String[] formats;
        PMML result = null;
        for (String format : formats = this.format.split(",")) {
            PMML pmml = this.loadPMML(this.getLearnerPath(format), this.getFeatureMapPath());
            if (result != null) {
                this.assertEquals(result, pmml);
            }
            result = pmml;
        }
        return result;
    }

    public String getInputCsvPath() {
        return "/csv/" + XGBoostEncoderBatch.truncate((String)this.getDataset()) + ".csv";
    }

    public String getOutputCsvPath() {
        return super.getOutputCsvPath();
    }

    protected PMML loadPMML(String learnerPath, String featureMapPath) throws Exception {
        FeatureMap featureMap;
        Learner learner;
        try (InputStream is = this.open(learnerPath);){
            learner = XGBoostUtil.loadLearner(is);
        }
        try (InputStream is = this.open(featureMapPath);){
            featureMap = XGBoostUtil.loadFeatureMap(is);
        }
        Map options = this.getOptions();
        PMML pmml = learner.encodePMML(options, null, null, featureMap);
        this.validatePMML(pmml);
        return pmml;
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    private void assertEquals(PMML left, PMML right) {
        Header leftHeader = left.requireHeader();
        Header rightHeader = right.requireHeader();
        Timestamp leftTimestamp = leftHeader.getTimestamp();
        Timestamp rightTimestamp = rightHeader.getTimestamp();
        try {
            leftHeader.setTimestamp(null);
            rightHeader.setTimestamp(null);
            boolean equals = ReflectionUtil.equals((PMMLObject)left, (PMMLObject)right);
            if (!equals) {
                throw new AssertionError();
            }
        }
        finally {
            leftHeader.setTimestamp(leftTimestamp);
            rightHeader.setTimestamp(rightTimestamp);
        }
    }
}

