/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.ml.bundle.ops.classification;

import ml.combust.bundle.BundleContext;
import ml.combust.bundle.dsl.Bundle$BuiltinOps$classification$;
import ml.combust.bundle.dsl.HasAttributes;
import ml.combust.bundle.dsl.Model;
import ml.combust.bundle.dsl.NodeShape;
import ml.combust.bundle.dsl.Value;
import ml.combust.bundle.dsl.Value$;
import ml.combust.bundle.op.OpModel;
import ml.combust.bundle.op.OpModel$class;
import ml.combust.mleap.tensor.DenseTensor;
import ml.combust.mleap.tensor.Tensor;
import org.apache.spark.ml.bundle.ParamSpec;
import org.apache.spark.ml.bundle.ParamSpec$;
import org.apache.spark.ml.bundle.SimpleParamSpec;
import org.apache.spark.ml.bundle.SimpleSparkOp;
import org.apache.spark.ml.bundle.SparkBundleContext;
import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.ml.linalg.Matrices$;
import org.apache.spark.ml.linalg.Matrix;
import org.apache.spark.ml.linalg.Vectors$;
import org.apache.spark.ml.param.Param;
import scala.Function0;
import scala.Function1;
import scala.None$;
import scala.Option;
import scala.Predef;
import scala.Predef$;
import scala.Serializable;
import scala.Some;
import scala.Tuple2;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

@ScalaSignature(bytes="\u0006\u000194A!\u0001\u0002\u0001#\t!Bj\\4jgRL7MU3he\u0016\u001c8/[8o\u001fBT!a\u0001\u0003\u0002\u001d\rd\u0017m]:jM&\u001c\u0017\r^5p]*\u0011QAB\u0001\u0004_B\u001c(BA\u0004\t\u0003\u0019\u0011WO\u001c3mK*\u0011\u0011BC\u0001\u0003[2T!a\u0003\u0007\u0002\u000bM\u0004\u0018M]6\u000b\u00055q\u0011AB1qC\u000eDWMC\u0001\u0010\u0003\ry'oZ\u0002\u0001'\t\u0001!\u0003E\u0002\u0014)Yi\u0011AB\u0005\u0003+\u0019\u0011QbU5na2,7\u000b]1sW>\u0003\bCA\f\u001a\u001b\u0005A\"BA\u0002\t\u0013\tQ\u0002DA\fM_\u001eL7\u000f^5d%\u0016<'/Z:tS>tWj\u001c3fY\")A\u0004\u0001C\u0001;\u00051A(\u001b8jiz\"\u0012A\b\t\u0003?\u0001i\u0011A\u0001\u0005\bC\u0001\u0011\r\u0011\"\u0004#\u0003\u0015bujR%T)&\u001buLU#H%\u0016\u001b6+S(O?\u0012+e)Q+M)~#\u0006JU#T\u0011>cE)F\u0001$\u001f\u0005!\u0003\u0005C a\u0002\u0001\u0001\u0001\u0001\u0001\u0001\t\r\u0019\u0002\u0001\u0015!\u0004$\u0003\u0019bujR%T)&\u001buLU#H%\u0016\u001b6+S(O?\u0012+e)Q+M)~#\u0006JU#T\u0011>cE\t\t\u0005\bQ\u0001\u0011\r\u0011\"\u0011*\u0003\u0015iu\u000eZ3m+\u0005Q\u0003\u0003B\u00163iYi\u0011\u0001\f\u0006\u0003[9\n!a\u001c9\u000b\u0005\u001dy#B\u0001\u00192\u0003\u001d\u0019w.\u001c2vgRT\u0011!C\u0005\u0003g1\u0012qa\u00149N_\u0012,G\u000e\u0005\u0002\u0014k%\u0011aG\u0002\u0002\u0013'B\f'o\u001b\"v]\u0012dWmQ8oi\u0016DH\u000f\u0003\u00049\u0001\u0001\u0006IAK\u0001\u0007\u001b>$W\r\u001c\u0011\t\u000bi\u0002A\u0011I\u001e\u0002\u0013M\u0004\u0018M]6M_\u0006$G\u0003\u0002\f=\u0011BCQ!P\u001dA\u0002y\n1!^5e!\tyTI\u0004\u0002A\u00076\t\u0011IC\u0001C\u0003\u0015\u00198-\u00197b\u0013\t!\u0015)\u0001\u0004Qe\u0016$WMZ\u0005\u0003\r\u001e\u0013aa\u0015;sS:<'B\u0001#B\u0011\u0015I\u0015\b1\u0001K\u0003\u0015\u0019\b.\u00199f!\tYe*D\u0001M\u0015\tie&A\u0002eg2L!a\u0014'\u0003\u00139{G-Z*iCB,\u0007\"B):\u0001\u00041\u0012!B7pI\u0016d\u0007\"B*\u0001\t\u0003\"\u0016aC:qCJ\\\u0017J\u001c9viN$\"!\u00163\u0011\u0007Ys\u0016M\u0004\u0002X9:\u0011\u0001lW\u0007\u00023*\u0011!\fE\u0001\u0007yI|w\u000e\u001e 
\n\u0003\tK!!X!\u0002\u000fA\f7m[1hK&\u0011q\f\u0019\u0002\u0004'\u0016\f(BA/B!\t\u0019\"-\u0003\u0002d\r\tI\u0001+\u0019:b[N\u0003Xm\u0019\u0005\u0006KJ\u0003\rAF\u0001\u0004_\nT\u0007\"B4\u0001\t\u0003B\u0017\u0001D:qCJ\\w*\u001e;qkR\u001cHCA5n!\r1fL\u001b\t\u0003'-L!\u0001\u001c\u0004\u0003\u001fMKW\u000e\u001d7f!\u0006\u0014\u0018-\\*qK\u000eDQ!\u001a4A\u0002Y\u0001")
public class LogisticRegressionOp
extends SimpleSparkOp<LogisticRegressionModel> {
    /**
     * Decision threshold applied to a binary model when the bundle carries no
     * explicit {@code "threshold"} attribute. Initialised here so the field is
     * definitely assigned (a blank {@code final} field is rejected by javac).
     */
    private final double LOGISTIC_REGRESSION_DEFAULT_THRESHOLD = 0.5;

    /**
     * Bundle serializer/deserializer for {@link LogisticRegressionModel}.
     *
     * <p>Attribute layout written by {@code store} and read back by {@code load}:
     * <ul>
     *   <li>always: {@code "num_classes"}</li>
     *   <li>more than two classes: {@code "coefficient_matrix"}, {@code "intercept_vector"}
     *       and an optional {@code "thresholds"}</li>
     *   <li>otherwise: {@code "coefficients"}, {@code "intercept"} and {@code "threshold"}</li>
     * </ul>
     *
     * <p>NOTE(review): this file is decompiler output (see the file header); the anonymous
     * {@code Serializable} function objects and the unresolved {@code <T>} type arguments
     * below are kept exactly as the decompiler rendered them.
     */
    private final OpModel<SparkBundleContext, LogisticRegressionModel> Model = new OpModel<SparkBundleContext, LogisticRegressionModel>(this){
        private final Class<LogisticRegressionModel> klazz;

        /** Delegates to the default implementation provided by the {@code OpModel} trait. */
        public String modelOpName(Object obj, BundleContext context) {
            return OpModel$class.modelOpName(this, obj, context);
        }

        /** Returns the model class handled by this serializer. */
        public Class<LogisticRegressionModel> klazz() {
            return this.klazz;
        }

        /** Returns the built-in bundle op name for logistic regression. */
        public String opName() {
            return Bundle$BuiltinOps$classification$.MODULE$.logistic_regression();
        }

        /**
         * Writes the attributes of {@code obj} onto {@code model}.
         *
         * @param model   bundle model to extend with attributes
         * @param obj     Spark model being serialized
         * @param context bundle context (not used by this method)
         * @return the bundle model carrying the attributes listed on the enclosing field
         */
        public Model store(Model model, LogisticRegressionModel obj, BundleContext<SparkBundleContext> context) {
            Model m2 = (Model)model.withValue("num_classes", Value$.MODULE$.long(obj.numClasses()));
            if (obj.numClasses() > 2) {
                Matrix cm = obj.coefficientMatrix();
                // Thresholds are wrapped in Some only when explicitly set on the model;
                // otherwise None is handed to withValue. Declared as Option (not None$)
                // because either alternative may be assigned.
                Option<?> thresholds = obj.isSet((Param)obj.thresholds()) ? new Some((Object)obj.getThresholds()) : None$.MODULE$;
                // NOTE(review): the flat array from cm.toArray() is read back through
                // Matrices.dense in load(); confirm both use the same element ordering.
                double[] matrixValues = cm.toArray();
                Seq<Object> matrixDims = (Seq<Object>)((Seq)Seq$.MODULE$.apply((Seq)Predef$.MODULE$.wrapIntArray(new int[]{cm.numRows(), cm.numCols()})));
                HasAttributes withMatrix = (HasAttributes)m2.withValue("coefficient_matrix", Value$.MODULE$.tensor(new DenseTensor<T>(matrixValues, matrixDims, ClassTag$.MODULE$.Double())));
                HasAttributes withIntercepts = (HasAttributes)withMatrix.withValue("intercept_vector", Value$.MODULE$.vector(obj.interceptVector().toArray(), ClassTag$.MODULE$.Double()));
                // double[] -> Seq -> Value (double list), applied only when thresholds are present.
                Option<Value> thresholdsValue = (Option<Value>)thresholds.map((Function1)new Serializable(this){
                    public static final long serialVersionUID = 0L;

                    public final Seq<Object> apply(double[] x$1) {
                        return Predef$.MODULE$.doubleArrayOps(x$1).toSeq();
                    }
                }).map((Function1)new Serializable(this){
                    public static final long serialVersionUID = 0L;

                    public final Value apply(Seq<Object> value2) {
                        return Value$.MODULE$.doubleList(value2);
                    }
                });
                return (Model)withIntercepts.withValue("thresholds", thresholdsValue);
            }
            HasAttributes withCoefficients = (HasAttributes)m2.withValue("coefficients", Value$.MODULE$.vector(obj.coefficients().toArray(), ClassTag$.MODULE$.Double()));
            HasAttributes withIntercept = (HasAttributes)withCoefficients.withValue("intercept", Value$.MODULE$.double(obj.intercept()));
            return (Model)withIntercept.withValue("threshold", Value$.MODULE$.double(obj.getThreshold()));
        }

        /**
         * Rebuilds a {@link LogisticRegressionModel} (with an empty uid) from bundle attributes.
         *
         * @param model   bundle model holding the attributes written by {@code store}
         * @param context bundle context (not used by this method)
         * @return the reconstructed model; a binary model without a stored
         *         {@code "threshold"} gets the default threshold
         */
        public LogisticRegressionModel load(Model model, BundleContext<SparkBundleContext> context) {
            long numClasses = model.value("num_classes").getLong();
            if (numClasses > 2L) {
                Tensor<T> cmTensor = model.value("coefficient_matrix").getTensor();
                // The tensor dimensions were stored as {numRows, numCols}.
                int numRows = BoxesRunTime.unboxToInt((Object)cmTensor.dimensions().head());
                int numCols = BoxesRunTime.unboxToInt((Object)cmTensor.dimensions().apply(1));
                Matrix coefficientMatrix = Matrices$.MODULE$.dense(numRows, numCols, (double[])cmTensor.toArray());
                // NOTE(review): the trailing `true` is presumably the multinomial flag — confirm
                // against the Spark constructor.
                LogisticRegressionModel lr = new LogisticRegressionModel("", coefficientMatrix, Vectors$.MODULE$.dense((double[])model.value("intercept_vector").getTensor().toArray()), (int)numClasses, true);
                // Apply stored thresholds when present, otherwise return the model untouched.
                return (LogisticRegressionModel)model.getValue("thresholds").map((Function1)new Serializable(this, lr){
                    public static final long serialVersionUID = 0L;
                    private final LogisticRegressionModel lr$1;

                    public final LogisticRegressionModel apply(Value t2) {
                        return this.lr$1.setThresholds((double[])t2.getDoubleList().toArray(ClassTag$.MODULE$.Double()));
                    }
                    {
                        this.lr$1 = lr$1;
                    }
                }).getOrElse((Function0)new Serializable(this, lr){
                    public static final long serialVersionUID = 0L;
                    private final LogisticRegressionModel lr$1;

                    public final LogisticRegressionModel apply() {
                        return this.lr$1;
                    }
                    {
                        this.lr$1 = lr$1;
                    }
                });
            }
            LogisticRegressionModel lr = new LogisticRegressionModel("", Vectors$.MODULE$.dense((double[])model.value("coefficients").getTensor().toArray()), model.value("intercept").getDouble());
            double threshold2 = BoxesRunTime.unboxToDouble((Object)model.getValue("threshold").map((Function1)new Serializable(this){
                public static final long serialVersionUID = 0L;

                public final double apply(Value value2) {
                    return value2.getDouble();
                }
            }).getOrElse((Function0)new Serializable(this){
                public static final long serialVersionUID = 0L;

                public final double apply() {
                    return this.apply$mcD$sp();
                }

                public double apply$mcD$sp() {
                    // Single source of truth for the fallback instead of a repeated literal.
                    return LogisticRegressionOp.this.LOGISTIC_REGRESSION_DEFAULT_THRESHOLD();
                }
            }));
            return lr.setThreshold(threshold2);
        }
        {
            OpModel$class.$init$(this);
            this.klazz = LogisticRegressionModel.class;
        }
    };

    /**
     * Accessor for the default logistic-regression threshold.
     *
     * @return the constant {@code 0.5}
     */
    private final double LOGISTIC_REGRESSION_DEFAULT_THRESHOLD() {
        return 0.5;
    }

    /**
     * Returns the bundle serializer/deserializer held in this op's {@code Model} field.
     *
     * @return the {@link OpModel} instance created when this op was constructed
     */
    @Override
    public OpModel<SparkBundleContext, LogisticRegressionModel> Model() {
        return this.Model;
    }

    /**
     * Creates a new {@link LogisticRegressionModel} under the given uid from an
     * already-loaded model.
     *
     * <p>The coefficient matrix, intercept vector and class count are taken from
     * {@code model}. With more than two classes the final constructor flag is
     * {@code true} and {@code thresholds} is copied when defined; otherwise the flag
     * is {@code false} and {@code threshold} is copied when defined.
     *
     * @param uid2   uid for the new model
     * @param shape2 node shape (not used by this method)
     * @param model  model to copy values from
     * @return the newly constructed model
     */
    @Override
    public LogisticRegressionModel sparkLoad(String uid2, NodeShape shape2, LogisticRegressionModel model) {
        final int numClasses = model.numClasses();
        final boolean moreThanTwoClasses = numClasses > 2;
        final LogisticRegressionModel rebuilt = new LogisticRegressionModel(uid2, model.coefficientMatrix(), model.interceptVector(), numClasses, moreThanTwoClasses);
        if (moreThanTwoClasses) {
            if (model.isDefined((Param)model.thresholds())) {
                rebuilt.setThresholds(model.getThresholds());
            }
        } else if (model.isDefined((Param)model.threshold())) {
            rebuilt.setThreshold(model.getThreshold());
        }
        return rebuilt;
    }

    /**
     * Describes the op's single input port.
     *
     * @param obj model whose column parameters are exposed
     * @return a one-element sequence pairing the {@code "features"} port with
     *         the model's features-column parameter
     */
    @Override
    public Seq<ParamSpec> sparkInputs(LogisticRegressionModel obj) {
        // Port name paired with the Param holding the column name.
        final Tuple2<String, Param<String>> featuresPort = new Tuple2<String, Param<String>>("features", obj.featuresCol());
        final SimpleParamSpec[] specs = new SimpleParamSpec[]{ParamSpec$.MODULE$.apply(featuresPort)};
        return (Seq)Seq$.MODULE$.apply((Seq)Predef$.MODULE$.wrapRefArray((Object[])specs));
    }

    /**
     * Describes the op's output ports.
     *
     * @param obj model whose column parameters are exposed
     * @return a sequence of three specs, in order: {@code "raw_prediction"},
     *         {@code "probability"} and {@code "prediction"}, each paired with
     *         the corresponding column parameter of {@code obj}
     */
    @Override
    public Seq<SimpleParamSpec> sparkOutputs(LogisticRegressionModel obj) {
        // Each spec is built in full before the next one, in declaration order.
        final SimpleParamSpec rawPrediction = ParamSpec$.MODULE$.apply(new Tuple2<String, Param<String>>("raw_prediction", obj.rawPredictionCol()));
        final SimpleParamSpec probability = ParamSpec$.MODULE$.apply(new Tuple2<String, Param<String>>("probability", obj.probabilityCol()));
        final SimpleParamSpec prediction = ParamSpec$.MODULE$.apply(new Tuple2<String, Param<String>>("prediction", obj.predictionCol()));
        final SimpleParamSpec[] specs = new SimpleParamSpec[]{rawPrediction, probability, prediction};
        return (Seq)Seq$.MODULE$.apply((Seq)Predef$.MODULE$.wrapRefArray((Object[])specs));
    }

    /**
     * Creates the op, passing a {@code ClassTag} for {@link LogisticRegressionModel}
     * to the {@code SimpleSparkOp} superclass constructor.
     */
    public LogisticRegressionOp() {
        super(ClassTag$.MODULE$.apply(LogisticRegressionModel.class));
    }
}

