/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.ml.bundle.ops.classification;

import ml.bundle.Socket.Socket;
import ml.combust.bundle.BundleContext;
import ml.combust.bundle.dsl.Bundle;
import ml.combust.bundle.dsl.HasAttributeList;
import ml.combust.bundle.dsl.Model;
import ml.combust.bundle.dsl.Node;
import ml.combust.bundle.dsl.Shape;
import ml.combust.bundle.dsl.Shape$;
import ml.combust.bundle.dsl.Value;
import ml.combust.bundle.dsl.Value$;
import ml.combust.bundle.op.OpModel;
import ml.combust.bundle.op.OpNode;
import org.apache.spark.ml.bundle.SparkBundleContext;
import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.ml.linalg.Vectors$;
import org.apache.spark.ml.param.Param;
import scala.Function0;
import scala.Function1;
import scala.None$;
import scala.Option;
import scala.Predef$;
import scala.Serializable;
import scala.Some;
import scala.collection.Seq;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.ObjectRef;

@ScalaSignature(bytes="\u0006\u0001%4A!\u0001\u0002\u0001#\t!Bj\\4jgRL7MU3he\u0016\u001c8/[8o\u001fBT!a\u0001\u0003\u0002\u001d\rd\u0017m]:jM&\u001c\u0017\r^5p]*\u0011QAB\u0001\u0004_B\u001c(BA\u0004\t\u0003\u0019\u0011WO\u001c3mK*\u0011\u0011BC\u0001\u0003[2T!a\u0003\u0007\u0002\u000bM\u0004\u0018M]6\u000b\u00055q\u0011AB1qC\u000eDWMC\u0001\u0010\u0003\ry'oZ\u0002\u0001'\r\u0001!\u0003\u0007\t\u0003'Yi\u0011\u0001\u0006\u0006\u0002+\u0005)1oY1mC&\u0011q\u0003\u0006\u0002\u0007\u0003:L(+\u001a4\u0011\u000be\u0001#E\n\u0014\u000e\u0003iQ!a\u0007\u000f\u0002\u0005=\u0004(BA\u0004\u001e\u0015\tqr$A\u0004d_6\u0014Wo\u001d;\u000b\u0003%I!!\t\u000e\u0003\r=\u0003hj\u001c3f!\t\u0019C%D\u0001\u0007\u0013\t)cA\u0001\nTa\u0006\u00148NQ;oI2,7i\u001c8uKb$\bCA\u0014*\u001b\u0005A#BA\u0002\t\u0013\tQ\u0003FA\fM_\u001eL7\u000f^5d%\u0016<'/Z:tS>tWj\u001c3fY\")A\u0006\u0001C\u0001[\u00051A(\u001b8jiz\"\u0012A\f\t\u0003_\u0001i\u0011A\u0001\u0005\bc\u0001\u0011\r\u0011\"\u00113\u0003\u0015iu\u000eZ3m+\u0005\u0019\u0004\u0003B\r5E\u0019J!!\u000e\u000e\u0003\u000f=\u0003Xj\u001c3fY\"1q\u0007\u0001Q\u0001\nM\na!T8eK2\u0004\u0003bB\u001d\u0001\u0005\u0004%\tEO\u0001\u0006W2\f'P_\u000b\u0002wA\u0019Ah\u0010\u0014\u000f\u0005Mi\u0014B\u0001 
\u0015\u0003\u0019\u0001&/\u001a3fM&\u0011\u0001)\u0011\u0002\u0006\u00072\f7o\u001d\u0006\u0003}QAaa\u0011\u0001!\u0002\u0013Y\u0014AB6mCjT\b\u0005C\u0003F\u0001\u0011\u0005c)\u0001\u0003oC6,GCA$K!\ta\u0004*\u0003\u0002J\u0003\n11\u000b\u001e:j]\u001eDQa\u0013#A\u0002\u0019\nAA\\8eK\")Q\n\u0001C!\u001d\u0006)Qn\u001c3fYR\u0011ae\u0014\u0005\u0006\u00172\u0003\rA\n\u0005\u0006#\u0002!\tEU\u0001\u0005Y>\fG\rF\u0002T5\u0006$\"A\n+\t\u000bU\u0003\u00069\u0001,\u0002\u000f\r|g\u000e^3yiB\u0019q\u000b\u0017\u0012\u000e\u0003qI!!\u0017\u000f\u0003\u001b\t+h\u000e\u001a7f\u0007>tG/\u001a=u\u0011\u0015Y\u0005\u000b1\u0001\\!\tav,D\u0001^\u0015\tqF$A\u0002eg2L!\u0001Y/\u0003\t9{G-\u001a\u0005\u0006\u001bB\u0003\rA\n\u0005\u0006G\u0002!\t\u0005Z\u0001\u0006g\"\f\u0007/\u001a\u000b\u0003K\"\u0004\"\u0001\u00184\n\u0005\u001dl&!B*iCB,\u0007\"B&c\u0001\u00041\u0003")
/**
 * MLeap bundle op that serializes and deserializes Spark's {@link LogisticRegressionModel}.
 *
 * <p>NOTE: this file is CFR decompiler output of compiled Scala (see the {@code ScalaSignature}
 * annotation), not hand-written Java. Constructs such as {@code OpNode.class.children(...)},
 * {@code OpNode.class.$init$(...)}, {@code ObjectRef}, and the {@code new Serializable(this, lr){...}}
 * anonymous classes appear to be decompiler renderings of Scala trait forwarders, a captured
 * {@code var}, and closures; they do not compile as Java in this form.
 *
 * <p>The model's attributes ({@code coefficients}, {@code intercept}, {@code num_classes} and an
 * optional {@code threshold}) are handled by the {@link OpModel} built in the constructor. The
 * node's wiring is a {@link Shape} with a {@code features} input and {@code prediction},
 * {@code raw_prediction} and {@code probability} outputs.
 */
public class LogisticRegressionOp
implements OpNode<SparkBundleContext, LogisticRegressionModel, LogisticRegressionModel> {
    /** Attribute serializer for the model; created as an anonymous class in the constructor. */
    private final OpModel<SparkBundleContext, LogisticRegressionModel> Model;
    /** Runtime class of the node type handled by this op ({@code LogisticRegressionModel.class}). */
    private final Class<LogisticRegressionModel> klazz;

    /**
     * Returns the child nodes of {@code node}.
     *
     * <p>Not overridden by this op: delegates to the default implementation on the {@code OpNode}
     * trait (decompiled form of a trait forwarder).
     */
    public Option children(Object node) {
        return OpNode.class.children((OpNode)this, (Object)node);
    }

    /** Returns the attribute serializer for {@link LogisticRegressionModel}. */
    public OpModel<SparkBundleContext, LogisticRegressionModel> Model() {
        return this.Model;
    }

    /** Returns {@code LogisticRegressionModel.class}. */
    public Class<LogisticRegressionModel> klazz() {
        return this.klazz;
    }

    /** Uses the Spark model's uid as the bundle node name. */
    public String name(LogisticRegressionModel node) {
        return node.uid();
    }

    /** The Spark node is its own model object, so it is returned unchanged. */
    public LogisticRegressionModel model(LogisticRegressionModel node) {
        return node;
    }

    /**
     * Rebuilds a Spark {@link LogisticRegressionModel} for a deserialized bundle node.
     *
     * <p>The steps are:
     * <ul>
     *   <li>Construct a new model whose uid is {@code node.name()}, using the coefficients and
     *       intercept of {@code model}. This is presumably the object returned by
     *       {@code Model().load}.</li>
     *   <li>Copy the param values of {@code model} onto it.</li>
     *   <li>Set the features and prediction columns from the node's {@code "features"} input and
     *       {@code "prediction"} output.</li>
     *   <li>Set the probability column only if the shape has a {@code "probability"} output.</li>
     *   <li>Set the raw-prediction column only if the shape has a {@code "raw_prediction"} output.</li>
     * </ul>
     *
     * @param node    deserialized bundle node supplying the name and input/output shape
     * @param model   model object carrying the loaded coefficients, intercept and params
     * @param context bundle context (not referenced by this method)
     * @return the reconstructed model with its columns configured
     */
    public LogisticRegressionModel load(Node node, LogisticRegressionModel model, BundleContext<SparkBundleContext> context) {
        // lr is boxed in an ObjectRef so the closures below can read it; this appears to be the
        // decompiled form of a Scala local `var` captured by lambdas.
        ObjectRef lr = ObjectRef.create((Object)((LogisticRegressionModel)new LogisticRegressionModel(node.name(), model.coefficients(), model.intercept()).copy(model.extractParamMap()).setFeaturesCol(node.shape().input("features").name()).setPredictionCol(node.shape().output("prediction").name())));
        // Optional "probability" output: if present, set its name as the probability column;
        // otherwise keep lr as-is.
        lr.elem = (LogisticRegressionModel)node.shape().getOutput("probability").map((Function1)new Serializable(this, lr){
            public static final long serialVersionUID = 0L;
            private final ObjectRef lr$2;

            public final LogisticRegressionModel apply(Socket p) {
                return (LogisticRegressionModel)((LogisticRegressionModel)this.lr$2.elem).setProbabilityCol(p.name());
            }
            {
                this.lr$2 = lr$2;
            }
        }).getOrElse((Function0)new Serializable(this, lr){
            public static final long serialVersionUID = 0L;
            private final ObjectRef lr$2;

            public final LogisticRegressionModel apply() {
                return (LogisticRegressionModel)this.lr$2.elem;
            }
            {
                this.lr$2 = lr$2;
            }
        });
        // Optional "raw_prediction" output: same pattern, for the raw-prediction column.
        return (LogisticRegressionModel)node.shape().getOutput("raw_prediction").map((Function1)new Serializable(this, lr){
            public static final long serialVersionUID = 0L;
            private final ObjectRef lr$2;

            public final LogisticRegressionModel apply(Socket rp) {
                return (LogisticRegressionModel)((LogisticRegressionModel)this.lr$2.elem).setRawPredictionCol(rp.name());
            }
            {
                this.lr$2 = lr$2;
            }
        }).getOrElse((Function0)new Serializable(this, lr){
            public static final long serialVersionUID = 0L;
            private final ObjectRef lr$2;

            public final LogisticRegressionModel apply() {
                return (LogisticRegressionModel)this.lr$2.elem;
            }
            {
                this.lr$2 = lr$2;
            }
        });
    }

    /**
     * Describes the node's input/output sockets for serialization.
     *
     * <p>Sockets written:
     * <ul>
     *   <li>{@code features}: always written, as the input.</li>
     *   <li>{@code prediction}: always written, as an output.</li>
     *   <li>{@code raw_prediction}: written as an output only if {@code rawPredictionCol} is
     *       defined on {@code node}.</li>
     *   <li>{@code probability}: written as an output only if {@code probabilityCol} is defined
     *       on {@code node}.</li>
     * </ul>
     *
     * @param node the Spark model being serialized
     * @return the shape mapping column names to socket names
     */
    public Shape shape(LogisticRegressionModel node) {
        // The declared type None$ is a decompiler artifact: each local holds either
        // Some(columnName) or None, depending on whether the param is defined.
        None$ rawPrediction = node.isDefined(node.rawPredictionCol()) ? new Some((Object)node.getRawPredictionCol()) : None$.MODULE$;
        None$ probability = node.isDefined(node.probabilityCol()) ? new Some((Object)node.getProbabilityCol()) : None$.MODULE$;
        return Shape$.MODULE$.apply().withInput(node.getFeaturesCol(), "features").withOutput(node.getPredictionCol(), "prediction").withOutput((Option)rawPrediction, "raw_prediction").withOutput((Option)probability, "probability");
    }

    /**
     * Creates the op and its attribute serializer.
     *
     * <p>The anonymous {@link OpModel} is registered under the built-in
     * {@code logistic_regression} op name.
     *
     * <p>{@code store} writes these attributes:
     * <ul>
     *   <li>{@code coefficients} (double vector)</li>
     *   <li>{@code intercept} (double)</li>
     *   <li>{@code num_classes} (long)</li>
     *   <li>{@code threshold} (double), from {@code obj.get(threshold)}; it is passed to
     *       {@code withAttr} as an Option, so it is presumably omitted when unset.</li>
     * </ul>
     *
     * <p>{@code load} rejects any {@code num_classes} other than 2. Otherwise it builds a model
     * with an empty uid from the stored coefficients and intercept, and applies the threshold if
     * present.
     */
    public LogisticRegressionOp() {
        // Decompiled trait initializer for OpNode.
        OpNode.class.$init$((OpNode)this);
        this.Model = new OpModel<SparkBundleContext, LogisticRegressionModel>(this){
            private final Class<LogisticRegressionModel> klazz;

            public Class<LogisticRegressionModel> klazz() {
                return this.klazz;
            }

            /** Bundle op name for binary logistic regression. */
            public String opName() {
                return Bundle.BuiltinOps$.classification$.MODULE$.logistic_regression();
            }

            /**
             * Writes the model's attributes into {@code model}.
             *
             * <p>NOTE(review): {@code num_classes} is written from {@code obj.numClasses()} without
             * validation, but {@code load} below throws for anything other than 2. A model stored
             * here with a different class count would therefore fail to reload. Confirm whether
             * non-binary models can reach this path.
             */
            public Model store(Model model, LogisticRegressionModel obj, BundleContext<SparkBundleContext> context) {
                return (Model)((HasAttributeList)((HasAttributeList)((HasAttributeList)model.withAttr("coefficients", Value$.MODULE$.doubleVector((Seq)Predef$.MODULE$.wrapDoubleArray(obj.coefficients().toArray())))).withAttr("intercept", Value$.MODULE$.double(obj.intercept()))).withAttr("num_classes", Value$.MODULE$.long((long)obj.numClasses()))).withAttr("threshold", obj.get((Param)obj.threshold()).map((Function1)new Serializable(this){
                    public static final long serialVersionUID = 0L;

                    public final Value apply(double value) {
                        return Value$.MODULE$.double(value);
                    }
                }));
            }

            /**
             * Reads the attributes written by {@code store} back into a Spark model.
             *
             * @throws IllegalArgumentException if {@code num_classes} is not 2
             */
            public LogisticRegressionModel load(Model model, BundleContext<SparkBundleContext> context) {
                // Only binary logistic regression is accepted.
                if (model.value("num_classes").getLong() != 2L) {
                    throw new IllegalArgumentException("Only binary logistic regression supported in Spark");
                }
                // The uid is left empty here; the node-level load(Node, ...) constructs a new model
                // that uses node.name() as its uid.
                LogisticRegressionModel lr = new LogisticRegressionModel("", Vectors$.MODULE$.dense((double[])model.value("coefficients").getDoubleVector().toArray(ClassTag$.MODULE$.Double())), model.value("intercept").getDouble());
                // Apply the threshold attribute if present; otherwise return lr unchanged.
                return (LogisticRegressionModel)model.getValue("threshold").map((Function1)new Serializable(this, lr){
                    public static final long serialVersionUID = 0L;
                    private final LogisticRegressionModel lr$1;

                    public final LogisticRegressionModel apply(Value t) {
                        return this.lr$1.setThreshold(t.getDouble());
                    }
                    {
                        this.lr$1 = lr$1;
                    }
                }).getOrElse((Function0)new Serializable(this, lr){
                    public static final long serialVersionUID = 0L;
                    private final LogisticRegressionModel lr$1;

                    public final LogisticRegressionModel apply() {
                        return this.lr$1;
                    }
                    {
                        this.lr$1 = lr$1;
                    }
                });
            }
            {
                this.klazz = LogisticRegressionModel.class;
            }
        };
        this.klazz = LogisticRegressionModel.class;
    }
}

