/*
 * Decompiled with CFR 0.152.
 */
package sklearn.compose;

import java.util.Collections;
import numpy.core.UFunc;
import numpy.core.UFuncUtil;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Expression;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.FieldRef;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.regression.RegressionModel;
import org.jpmml.converter.AbstractTransformation;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.Transformation;
import org.jpmml.converter.mining.MiningModelUtil;
import sklearn.FieldNameUtil;
import sklearn.Regressor;
import sklearn.preprocessing.FunctionTransformer;

public class TransformedTargetRegressor
extends Regressor {
    public TransformedTargetRegressor(String module, String name) {
        super(module, name);
    }

    @Override
    public Model encodeModel(Schema schema) {
        Regressor regressor = this.getRegressor();
        FunctionTransformer transformer = this.getTransformer();
        UFunc func = transformer.getFunc();
        final UFunc inverseFunc = transformer.getInverseFunc();
        if (inverseFunc == null) {
            return regressor.encodeModel(schema);
        }
        Label label = schema.getLabel();
        AbstractTransformation transformation = new AbstractTransformation(){

            public FieldName getName(FieldName name) {
                return FieldNameUtil.create("inverseFunc", name);
            }

            public Expression createExpression(FieldRef fieldRef) {
                return UFuncUtil.encodeUFunc((UFunc)inverseFunc, Collections.singletonList(fieldRef));
            }
        };
        FieldName name = label.getName();
        Schema segmentSchema = schema.toAnonymousSchema();
        Model model = regressor.encodeModel(segmentSchema).setOutput(ModelUtil.createPredictedOutput((FieldName)FieldNameUtil.create("func", name), (OpType)OpType.CONTINUOUS, (DataType)DataType.DOUBLE, (Transformation[])new Transformation[]{transformation}));
        return MiningModelUtil.createRegression((Model)model, (RegressionModel.NormalizationMethod)RegressionModel.NormalizationMethod.NONE, (Schema)schema);
    }

    public Regressor getRegressor() {
        return (Regressor)this.get("regressor_", Regressor.class);
    }

    public FunctionTransformer getTransformer() {
        return (FunctionTransformer)this.get("transformer_", FunctionTransformer.class);
    }
}

