/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.lightgbm.testing;

import com.google.common.base.Equivalence;
import java.io.IOException;
import java.io.InputStream;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Predicate;
import org.dmg.pmml.MiningField;
import org.dmg.pmml.MiningSchema;
import org.dmg.pmml.PMML;
import org.dmg.pmml.PMMLObject;
import org.dmg.pmml.Visitable;
import org.dmg.pmml.VisitorAction;
import org.dmg.pmml.mining.MiningModel;
import org.jpmml.evaluator.ResultField;
import org.jpmml.evaluator.testing.IntegrationTestBatch;
import org.jpmml.lightgbm.GBDT;
import org.jpmml.lightgbm.LightGBMUtil;
import org.jpmml.lightgbm.testing.LightGBMTest;
import org.jpmml.model.visitors.AbstractVisitor;
import org.junit.Assert;

public abstract class LightGBMTestBatch
extends IntegrationTestBatch {
    public LightGBMTestBatch(String name, String dataset, Predicate<ResultField> predicate, Equivalence<Object> equivalence) {
        super(name, dataset, predicate, equivalence);
    }

    public abstract LightGBMTest getIntegrationTest();

    public Map<String, Object> getOptions() {
        String[] dataset = this.parseDataset();
        Integer numIteration = null;
        if (dataset.length > 1) {
            numIteration = new Integer(dataset[1]);
        }
        LinkedHashMap<String, Object> options = new LinkedHashMap<String, Object>();
        options.put("compact", numIteration != null);
        options.put("nan_as_missing", true);
        options.put("num_iteration", numIteration);
        return options;
    }

    public String getModelTxtPath() {
        String[] dataset = this.parseDataset();
        return "/lgbm/" + this.getName() + dataset[0] + ".txt";
    }

    public PMML getPMML() throws Exception {
        GBDT gbdt;
        try (InputStream is = this.open(this.getModelTxtPath());){
            gbdt = LightGBMUtil.loadGBDT((InputStream)is);
        }
        Map<String, Object> options = this.getOptions();
        PMML pmml = gbdt.encodePMML(options, null, null);
        this.validatePMML(pmml);
        return pmml;
    }

    public String getInputCsvPath() {
        String[] dataset = this.parseDataset();
        return "/csv/" + dataset[0] + ".csv";
    }

    public List<Map<String, String>> getInput() throws IOException {
        return this.loadRecords(this.getInputCsvPath());
    }

    public String getOutputCsvPath() {
        return "/csv/" + this.getName() + this.getDataset() + ".csv";
    }

    public List<Map<String, String>> getOutput() throws IOException {
        return this.loadRecords(this.getOutputCsvPath());
    }

    protected void validatePMML(PMML pmml) throws Exception {
        super.validatePMML(pmml);
        AbstractVisitor visitor = new AbstractVisitor(){

            public VisitorAction visit(MiningModel miningModel) {
                MiningSchema miningSchema;
                PMMLObject parent = this.getParent();
                if (parent instanceof PMML && (miningSchema = miningModel.getMiningSchema()).hasMiningFields()) {
                    List miningFields = miningSchema.getMiningFields();
                    for (MiningField miningField : miningFields) {
                        MiningField.UsageType usageType = miningField.getUsageType();
                        switch (usageType) {
                            case TARGET: {
                                Assert.assertNull((Object)miningField.getImportance());
                                break;
                            }
                            case ACTIVE: {
                                Assert.assertNotNull((Object)miningField.getImportance());
                                break;
                            }
                        }
                    }
                }
                return super.visit(miningModel);
            }
        };
        visitor.applyTo((Visitable)pmml);
    }

    protected String[] parseDataset() {
        String dataset = this.getDataset();
        int index = dataset.indexOf(64);
        if (index > -1) {
            return new String[]{dataset.substring(0, index), dataset.substring(index + 1)};
        }
        return new String[]{dataset};
    }
}

