/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.ml.bundle.ops.classification;

import ml.bundle.Socket.Socket;
import ml.combust.bundle.BundleContext;
import ml.combust.bundle.dsl.Bundle;
import ml.combust.bundle.dsl.HasAttributeList;
import ml.combust.bundle.dsl.Model;
import ml.combust.bundle.dsl.Node;
import ml.combust.bundle.dsl.Shape;
import ml.combust.bundle.dsl.Shape$;
import ml.combust.bundle.dsl.Value;
import ml.combust.bundle.dsl.Value$;
import ml.combust.bundle.op.OpModel;
import ml.combust.bundle.op.OpNode;
import org.apache.spark.ml.bundle.SparkBundleContext;
import org.apache.spark.ml.bundle.util.ParamUtil$;
import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.ml.linalg.Vectors$;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.param.Params;
import scala.Function0;
import scala.Function1;
import scala.None$;
import scala.Option;
import scala.Predef$;
import scala.Serializable;
import scala.Some;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.ObjectRef;

@ScalaSignature(bytes="\u0006\u0001%4A!\u0001\u0002\u0001#\t9Bj\\4jgRL7MU3he\u0016\u001c8/[8o\u001fB4&\u0007\r\u0006\u0003\u0007\u0011\tab\u00197bgNLg-[2bi&|gN\u0003\u0002\u0006\r\u0005\u0019q\u000e]:\u000b\u0005\u001dA\u0011A\u00022v]\u0012dWM\u0003\u0002\n\u0015\u0005\u0011Q\u000e\u001c\u0006\u0003\u00171\tQa\u001d9be.T!!\u0004\b\u0002\r\u0005\u0004\u0018m\u00195f\u0015\u0005y\u0011aA8sO\u000e\u00011c\u0001\u0001\u00131A\u00111CF\u0007\u0002))\tQ#A\u0003tG\u0006d\u0017-\u0003\u0002\u0018)\t1\u0011I\\=SK\u001a\u0004R!\u0007\u0011#M\u0019j\u0011A\u0007\u0006\u00037q\t!a\u001c9\u000b\u0005\u001di\"B\u0001\u0010 \u0003\u001d\u0019w.\u001c2vgRT\u0011!C\u0005\u0003Ci\u0011aa\u00149O_\u0012,\u0007CA\u0012%\u001b\u00051\u0011BA\u0013\u0007\u0005I\u0019\u0006/\u0019:l\u0005VtG\r\\3D_:$X\r\u001f;\u0011\u0005\u001dJS\"\u0001\u0015\u000b\u0005\rA\u0011B\u0001\u0016)\u0005]aunZ5ti&\u001c'+Z4sKN\u001c\u0018n\u001c8N_\u0012,G\u000eC\u0003-\u0001\u0011\u0005Q&\u0001\u0004=S:LGO\u0010\u000b\u0002]A\u0011q\u0006A\u0007\u0002\u0005!9\u0011\u0007\u0001b\u0001\n\u0003\u0012\u0014!B'pI\u0016dW#A\u001a\u0011\te!$EJ\u0005\u0003ki\u0011qa\u00149N_\u0012,G\u000e\u0003\u00048\u0001\u0001\u0006IaM\u0001\u0007\u001b>$W\r\u001c\u0011\t\u000fe\u0002!\u0019!C!u\u0005)1\u000e\\1{uV\t1\bE\u0002=\u007f\u0019r!aE\u001f\n\u0005y\"\u0012A\u0002)sK\u0012,g-\u0003\u0002A\u0003\n)1\t\\1tg*\u0011a\b\u0006\u0005\u0007\u0007\u0002\u0001\u000b\u0011B\u001e\u0002\r-d\u0017M\u001f>!\u0011\u0015)\u0005\u0001\"\u0011G\u0003\u0011q\u0017-\\3\u0015\u0005\u001dS\u0005C\u0001\u001fI\u0013\tI\u0015I\u0001\u0004TiJLgn\u001a\u0005\u0006\u0017\u0012\u0003\rAJ\u0001\u0005]>$W\rC\u0003N\u0001\u0011\u0005c*A\u0003n_\u0012,G\u000e\u0006\u0002'\u001f\")1\n\u0014a\u0001M!)\u0011\u000b\u0001C!%\u0006!An\\1e)\r\u0019&,\u0019\u000b\u0003MQCQ!\u0016)A\u0004Y\u000bqaY8oi\u0016DH\u000fE\u0002X1\nj\u0011\u0001H\u0005\u00033r\u0011QBQ;oI2,7i\u001c8uKb$\b\"B&Q\u0001\u0004Y\u0006C\u0001/`\u001b\u0005i&B\u00010\u001d\u0003\r!7\u000f\\\u
0005\u0003Av\u0013AAT8eK\")Q\n\u0015a\u0001M!)1\r\u0001C!I\u0006)1\u000f[1qKR\u0011Q\r\u001b\t\u00039\u001aL!aZ/\u0003\u000bMC\u0017\r]3\t\u000b-\u0013\u0007\u0019\u0001\u0014")
/*
 * Bundle operation (OpNode) that stores and loads a Spark LogisticRegressionModel.
 *
 * Only binary models are handled: store() rejects any model whose numClasses() is not 2,
 * and the inner OpModel's load() rejects any bundle whose "num_classes" attribute is not 2.
 *
 * This source is decompiler output of a Scala class. Constructs that were not valid Java
 * (pseudo-closures, a None$-typed local holding a Some) have been rewritten as plain Java
 * with the same behavior.
 */
public class LogisticRegressionOpV20
implements OpNode<SparkBundleContext, LogisticRegressionModel, LogisticRegressionModel> {
    private final OpModel<SparkBundleContext, LogisticRegressionModel> Model;
    private final Class<LogisticRegressionModel> klazz;

    /**
     * Forwards to the {@code children} implementation supplied by {@code OpNode}.
     *
     * NOTE(review): {@code OpNode.class.children(...)} is how the decompiler rendered this call;
     * it presumably targets the Scala trait implementation class of OpNode - confirm before
     * attempting to compile this file as Java.
     *
     * @param node the node whose children are requested
     * @return whatever the OpNode implementation returns for {@code node}
     */
    public Option children(Object node) {
        return OpNode.class.children((OpNode)this, (Object)node);
    }

    /**
     * @return the OpModel used to store and load the model attributes
     */
    public OpModel<SparkBundleContext, LogisticRegressionModel> Model() {
        return this.Model;
    }

    /**
     * @return the class object of the node type handled by this op
     */
    public Class<LogisticRegressionModel> klazz() {
        return this.klazz;
    }

    /**
     * @param node the Spark model acting as the node
     * @return the node name, which is the model's uid
     */
    public String name(LogisticRegressionModel node) {
        return node.uid();
    }

    /**
     * @param node the Spark model acting as the node
     * @return the model of the node; here node and model are the same object
     */
    public LogisticRegressionModel model(LogisticRegressionModel node) {
        return node;
    }

    /**
     * Builds the final LogisticRegressionModel for a bundle node.
     *
     * A new model is created with the node name as uid and the coefficients and intercept of
     * {@code model}. The features and prediction columns are taken from the node shape; the
     * probability and raw prediction columns are set only when the shape has those outputs.
     *
     * @param node    bundle node supplying the name and the shape
     * @param model   previously loaded model supplying coefficients, intercept and threshold
     * @param context bundle context (not used by this method)
     * @return the configured model
     */
    public LogisticRegressionModel load(Node node, LogisticRegressionModel model, BundleContext<SparkBundleContext> context) {
        LogisticRegressionModel lr = (LogisticRegressionModel)new LogisticRegressionModel(node.name(), model.coefficients(), model.intercept()).setFeaturesCol(node.shape().input("features").name()).setPredictionCol(node.shape().output("prediction").name());
        // Presumably copies the threshold param from `model` to `lr` only when it is set - verify against ParamUtil.
        ParamUtil$.MODULE$.setOptional((Params)lr, (Params)model, (Param)lr.threshold(), (Param)model.threshold());

        // "probability" is an optional output socket.
        Option<Socket> probability = node.shape().getOutput("probability");
        if (probability.isDefined()) {
            lr = (LogisticRegressionModel)lr.setProbabilityCol(probability.get().name());
        }

        // "raw_prediction" is an optional output socket as well.
        Option<Socket> rawPrediction = node.shape().getOutput("raw_prediction");
        if (rawPrediction.isDefined()) {
            lr = (LogisticRegressionModel)lr.setRawPredictionCol(rawPrediction.get().name());
        }
        return lr;
    }

    /**
     * Describes the input and output sockets of a model.
     *
     * The shape always has a "features" input and a "prediction" output. The "raw_prediction"
     * and "probability" outputs carry a column name only when the corresponding param is
     * defined on the model.
     *
     * @param node the Spark model acting as the node
     * @return the shape built from the model's column params
     */
    public Shape shape(LogisticRegressionModel node) {
        // Declared as Option: the value is either a Some or the empty Option, never only None.
        Option<String> rawPrediction = node.isDefined(node.rawPredictionCol())
                ? new Some<String>(node.getRawPredictionCol())
                : Option.<String>empty();
        Option<String> probability = node.isDefined(node.probabilityCol())
                ? new Some<String>(node.getProbabilityCol())
                : Option.<String>empty();
        return Shape$.MODULE$.apply().withInput(node.getFeaturesCol(), "features").withOutput(node.getPredictionCol(), "prediction").withOutput(rawPrediction, "raw_prediction").withOutput(probability, "probability");
    }

    /**
     * Creates the op and its OpModel.
     */
    public LogisticRegressionOpV20() {
        // NOTE(review): decompiler rendering of the OpNode trait initializer call - left as emitted; confirm target.
        OpNode.class.$init$((OpNode)this);
        this.Model = new OpModel<SparkBundleContext, LogisticRegressionModel>(){
            private final Class<LogisticRegressionModel> klazz = LogisticRegressionModel.class;

            /**
             * @return the class object of the model type handled by this OpModel
             */
            public Class<LogisticRegressionModel> klazz() {
                return this.klazz;
            }

            /**
             * @return the builtin op name for logistic regression
             */
            public String opName() {
                return Bundle.BuiltinOps$.classification$.MODULE$.logistic_regression();
            }

            /**
             * Writes the model attributes: "num_classes", "coefficients", "intercept" and "threshold".
             *
             * @param model   bundle model to add the attributes to
             * @param obj     Spark model being stored
             * @param context bundle context (not used by this method)
             * @return the bundle model with the attributes added
             * @throws AssertionError if {@code obj} is not a binary model
             */
            public Model store(Model model, LogisticRegressionModel obj, BundleContext<SparkBundleContext> context) {
                // Same failure as scala.Predef.assert(cond, message): AssertionError with the "assertion failed: " prefix.
                if (obj.numClasses() != 2) {
                    throw new AssertionError("assertion failed: " + "This op only supports binary logistic regression");
                }
                Model m = (Model)model.withAttr("num_classes", Value$.MODULE$.long((long)obj.numClasses()));
                return (Model)((HasAttributeList)((HasAttributeList)m.withAttr("coefficients", Value$.MODULE$.vector((Object)obj.coefficients().toArray(), ClassTag$.MODULE$.Double()))).withAttr("intercept", Value$.MODULE$.double(obj.intercept()))).withAttr("threshold", Value$.MODULE$.double(obj.getThreshold()));
            }

            /**
             * Reads a model from its bundle attributes.
             *
             * The model is created with an empty uid. The "threshold" attribute is optional and
             * is applied only when present.
             *
             * @param model   bundle model holding the attributes
             * @param context bundle context (not used by this method)
             * @return the loaded Spark model
             * @throws IllegalArgumentException if "num_classes" is not 2
             */
            public LogisticRegressionModel load(Model model, BundleContext<SparkBundleContext> context) {
                if (model.value("num_classes").getLong() != 2L) {
                    throw new IllegalArgumentException("Only binary logistic regression supported in Spark");
                }
                LogisticRegressionModel lr = new LogisticRegressionModel("", Vectors$.MODULE$.dense((double[])model.value("coefficients").getTensor().toArray()), model.value("intercept").getDouble());
                Option<Value> threshold = model.getValue("threshold");
                if (threshold.isDefined()) {
                    return lr.setThreshold(threshold.get().getDouble());
                }
                return lr;
            }
        };
        this.klazz = LogisticRegressionModel.class;
    }
}

