/*
 * Decompiled with CFR 0.152.
 */
package sklearn.linear_model;

import com.google.common.collect.Iterables;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import org.dmg.pmml.Model;
import org.dmg.pmml.mining.Segmentation;
import org.dmg.pmml.regression.RegressionModel;
import org.jpmml.converter.CMatrixUtil;
import org.jpmml.converter.Label;
import org.jpmml.converter.ScalarLabelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.mining.MiningModelUtil;
import org.jpmml.converter.regression.RegressionModelUtil;
import org.jpmml.python.ClassDictUtil;
import sklearn.SkLearnRegressor;

/**
 * Converter for scikit-learn linear regressors that expose {@code coef_} and
 * {@code intercept_} attributes.
 *
 * <p>The shape of {@code coef_} determines the layout: a 1-D array is treated as
 * a single-output model with one coefficient per feature, while a 2-D array is
 * treated as {@code (outputs, features)}.
 */
public class LinearRegressor extends SkLearnRegressor {

    public LinearRegressor(String module, String name) {
        super(module, name);
    }

    /**
     * Returns the number of input features.
     *
     * @return the second dimension of {@code coef_} when it is 2-D, otherwise its first dimension
     */
    @Override
    public int getNumberOfFeatures() {
        int[] shape = getCoefShape();

        if (shape.length == 2) {
            return shape[1];
        }

        return shape[0];
    }

    /**
     * Returns the number of regression targets.
     *
     * @return the first dimension of {@code coef_} when it is 2-D, otherwise {@code 1}
     */
    @Override
    public int getNumberOfOutputs() {
        int[] shape = getCoefShape();

        if (shape.length == 2) {
            return shape[0];
        }

        return 1;
    }

    /**
     * Encodes this estimator as a PMML model.
     *
     * <p>A single-output estimator becomes one regression model. A multi-output
     * estimator becomes a model chain with one regression model per output. Each
     * segment is built from the matching row of {@code coef_} and the matching
     * {@code intercept_} element, and is relabeled to the matching scalar label.
     *
     * @param schema the input schema; its features supply the regression inputs
     * @return the encoded model
     * @throws IllegalArgumentException if the number of outputs is not positive
     */
    @Override
    public Model encodeModel(Schema schema) {
        List<? extends Number> coef = getCoef();
        List<? extends Number> intercept = getIntercept();

        int numberOfOutputs = getNumberOfOutputs();

        if (numberOfOutputs == 1) {
            // One target: the whole flat coefficient array forms a single equation,
            // and intercept_ must hold exactly one value.
            return createRegression(coef, Iterables.getOnlyElement(intercept), schema);
        }

        if (numberOfOutputs >= 2) {
            List<?> scalarLabels = ScalarLabelUtil.toScalarLabels(schema.getLabel());

            // Every output needs its own intercept and its own label.
            ClassDictUtil.checkSize(numberOfOutputs, intercept, scalarLabels);

            int numberOfFeatures = schema.getFeatures().size();

            List<RegressionModel> models = new ArrayList<>(numberOfOutputs);
            for (int i = 0; i < numberOfOutputs; i++) {
                Schema segmentSchema = schema.toRelabeledSchema((Label)scalarLabels.get(i));

                // Row i of coef_, viewed as a numberOfOutputs x numberOfFeatures matrix.
                List<? extends Number> rowCoef = CMatrixUtil.getRow(coef, numberOfOutputs, numberOfFeatures, i);

                models.add(createRegression(rowCoef, intercept.get(i), segmentSchema));
            }

            return MiningModelUtil.createMultiModelChain(models, Segmentation.MissingPredictionTreatment.CONTINUE);
        }

        throw new IllegalArgumentException("Expected a positive number of outputs, got " + numberOfOutputs);
    }

    /**
     * Builds a single regression model over all features in {@code schema}.
     *
     * @param coef one coefficient per feature of {@code schema}
     * @param intercept the intercept term
     * @param schema the schema supplying the features and target label
     * @return the regression model
     */
    protected RegressionModel createRegression(List<? extends Number> coef, Number intercept, Schema schema) {
        return RegressionModelUtil.createRegression(schema.getFeatures(), coef, intercept, null, schema);
    }

    /**
     * Returns the flattened contents of the {@code coef_} attribute.
     *
     * @return the coefficients
     */
    public List<? extends Number> getCoef() {
        return getNumberArray("coef_");
    }

    /**
     * Returns the shape of the {@code coef_} attribute.
     *
     * @return the array dimensions
     */
    public int[] getCoefShape() {
        return getArrayShape("coef_");
    }

    /**
     * Returns the contents of the {@code intercept_} attribute.
     *
     * @return the intercept values
     */
    public List<? extends Number> getIntercept() {
        return getNumberArray("intercept_");
    }
}

