/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.knime.visitors;

import java.util.Deque;
import java.util.List;
import java.util.Objects;
import org.dmg.pmml.DataDictionary;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.MiningField;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.MiningSchema;
import org.dmg.pmml.Model;
import org.dmg.pmml.PMML;
import org.dmg.pmml.PMMLObject;
import org.dmg.pmml.Target;
import org.dmg.pmml.Targets;
import org.dmg.pmml.VisitorAction;
import org.jpmml.evaluator.IndexableUtil;
import org.jpmml.evaluator.MissingFieldException;
import org.jpmml.evaluator.UnsupportedFeatureException;
import org.jpmml.model.visitors.AbstractModelVisitor;

public class RegressionTargetCorrector
extends AbstractModelVisitor {
    private Target.CastInteger castInteger = null;

    public RegressionTargetCorrector() {
        this(Target.CastInteger.ROUND);
    }

    public RegressionTargetCorrector(Target.CastInteger castInteger) {
        this.setCastInteger(Objects.requireNonNull(castInteger));
    }

    public VisitorAction visit(Model model) {
        MiningFunction miningFunction = model.getMiningFunction();
        switch (miningFunction) {
            case REGRESSION: {
                this.processRegressionModel(model);
                break;
            }
        }
        return VisitorAction.CONTINUE;
    }

    private void processRegressionModel(Model model) {
        DataDictionary dataDictionary;
        PMML pmml = this.getPMML();
        MiningField miningField = RegressionTargetCorrector.getTargetField(model);
        if (miningField == null) {
            return;
        }
        FieldName name = miningField.getName();
        DataField dataField = (DataField)IndexableUtil.find((Object)name, (List)(dataDictionary = pmml.getDataDictionary()).getDataFields());
        if (dataField == null) {
            throw new MissingFieldException(name, (PMMLObject)miningField);
        }
        DataType dataType = dataField.getDataType();
        switch (dataType) {
            case INTEGER: {
                break;
            }
            case FLOAT: 
            case DOUBLE: {
                return;
            }
            default: {
                throw new UnsupportedFeatureException((PMMLObject)dataField, (Enum)dataType);
            }
        }
        Targets targets = model.getTargets();
        if (targets != null) {
            Target target = (Target)IndexableUtil.find((Object)name, (List)targets.getTargets());
            if (target != null) {
                if (target.getCastInteger() != null) {
                    return;
                }
                target.setCastInteger(this.getCastInteger());
            } else {
                targets.addTargets(new Target[]{this.createTarget(name)});
            }
        } else {
            targets = new Targets().addTargets(new Target[]{this.createTarget(name)});
            model.setTargets(targets);
        }
    }

    private Target createTarget(FieldName name) {
        Target target = new Target().setField(name).setCastInteger(this.getCastInteger());
        return target;
    }

    private PMML getPMML() {
        Deque parents = this.getParents();
        return (PMML)parents.getLast();
    }

    public Target.CastInteger getCastInteger() {
        return this.castInteger;
    }

    private void setCastInteger(Target.CastInteger castInteger) {
        this.castInteger = castInteger;
    }

    private static MiningField getTargetField(Model model) {
        MiningSchema miningSchema = model.getMiningSchema();
        MiningField result = null;
        List miningFields = miningSchema.getMiningFields();
        for (MiningField miningField : miningFields) {
            MiningField.UsageType usageType = miningField.getUsageType();
            switch (usageType) {
                case TARGET: 
                case PREDICTED: {
                    if (result != null) {
                        throw new UnsupportedFeatureException((PMMLObject)miningSchema);
                    }
                    result = miningField;
                    break;
                }
            }
        }
        return result;
    }
}

