/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.ml.bundle.ops.classification;

import ml.bundle.Socket.Socket;
import ml.combust.bundle.dsl.Attribute;
import ml.combust.bundle.dsl.Bundle;
import ml.combust.bundle.dsl.ReadableModel;
import ml.combust.bundle.dsl.ReadableNode;
import ml.combust.bundle.dsl.Shape;
import ml.combust.bundle.dsl.Shape$;
import ml.combust.bundle.dsl.Value;
import ml.combust.bundle.dsl.Value$;
import ml.combust.bundle.dsl.WritableModel;
import ml.combust.bundle.op.OpModel;
import ml.combust.bundle.op.OpNode;
import ml.combust.bundle.serializer.BundleContext;
import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.ml.linalg.Vectors$;
import org.apache.spark.ml.param.Param;
import scala.Function0;
import scala.Function1;
import scala.Predef$;
import scala.Serializable;
import scala.collection.Seq;
import scala.reflect.ClassTag$;

/**
 * Bundle serialization op for Spark's {@link LogisticRegressionModel}.
 *
 * <p>This appears to be the decompiled form of a Scala singleton {@code object}. The trailing
 * {@code $} in the class name and the static {@code MODULE$} instance suggest this. The node type
 * and the model type of the {@link OpNode} are both the Spark model itself; see {@link #model}.
 *
 * <p>NOTE(review): as decompiled, this is not valid Java source. Examples:
 * <ul>
 *   <li>anonymous classes are created as {@code new Serializable(lr){...}}, but an interface
 *       cannot take constructor arguments;</li>
 *   <li>their initializers read undeclared names such as {@code lr$2};</li>
 *   <li>the {@code static final MODULE$} is assigned from the constructor.</li>
 * </ul>
 * Treat it as a readable view of the bytecode, not as something to recompile.
 */
public final class LogisticRegressionOp$
implements OpNode<LogisticRegressionModel, LogisticRegressionModel> {
    /** The singleton instance. It is assigned by the private constructor, which the static initializer runs. */
    public static final LogisticRegressionOp$ MODULE$;
    /** Serializer for the model's attributes: coefficients, intercept, num_classes and an optional threshold. */
    private final OpModel<LogisticRegressionModel> Model;

    static {
        // The result is discarded on purpose: the constructor stores `this` into MODULE$.
        new LogisticRegressionOp$();
    }

    /**
     * Returns the {@link OpModel} that stores and loads this op's model attributes.
     *
     * @return the model serializer built in the constructor
     */
    public OpModel<LogisticRegressionModel> Model() {
        return this.Model;
    }

    /**
     * Returns the bundle node name for a model, which is the model's Spark uid.
     *
     * @param node the Spark model
     * @return {@code node.uid()}
     */
    public String name(LogisticRegressionModel node) {
        return node.uid();
    }

    /**
     * Returns the model for a node. The node is its own model, so it is returned unchanged.
     *
     * @param node the Spark model
     * @return {@code node} itself
     */
    public LogisticRegressionModel model(LogisticRegressionModel node) {
        return node;
    }

    /**
     * Rebuilds a Spark model for a bundle node from an already-loaded {@code model}.
     *
     * <p>The steps are:
     * <ol>
     *   <li>Construct a new model whose uid is the node's name and whose coefficients and
     *       intercept come from {@code model}.</li>
     *   <li>Copy the params of {@code model} onto it via {@code copy(model.extractParamMap())}.</li>
     *   <li>Take the features column from the shape input named "features".</li>
     *   <li>Take the prediction column from the shape output named "prediction".</li>
     *   <li>If the shape also has a "probability" output, use that socket's name as the
     *       probability column. Otherwise, leave the probability column as copied.</li>
     * </ol>
     *
     * @param context the bundle context (not used by this method)
     * @param node    the node being read; supplies the uid and the input/output column names
     * @param model   presumably the model produced by {@link #Model()}'s {@code load} (its uid is
     *                {@code ""}); TODO confirm against the caller
     * @return the configured model
     */
    public LogisticRegressionModel load(BundleContext context, ReadableNode node, LogisticRegressionModel model) {
        LogisticRegressionModel lr = (LogisticRegressionModel)new LogisticRegressionModel(node.name(), model.coefficients(), model.intercept()).copy(model.extractParamMap()).setFeaturesCol(node.shape().input("features").name()).setPredictionCol(node.shape().output("prediction").name());
        // Optional output: set the probability column only when the bundle declares one.
        return (LogisticRegressionModel)node.shape().getOutput("probability").map((Function1)new Serializable(lr){
            public static final long serialVersionUID = 0L;
            private final LogisticRegressionModel lr$2;

            public final LogisticRegressionModel apply(Socket p) {
                return (LogisticRegressionModel)this.lr$2.setProbabilityCol(p.name());
            }
            {
                this.lr$2 = lr$2;
            }
        }).getOrElse((Function0)new Serializable(lr){
            public static final long serialVersionUID = 0L;
            private final LogisticRegressionModel lr$2;

            // No "probability" output in the shape: return lr unchanged.
            public final LogisticRegressionModel apply() {
                return this.lr$2;
            }
            {
                this.lr$2 = lr$2;
            }
        });
    }

    /**
     * Describes the node's inputs and outputs for the bundle.
     *
     * <p>The features column is registered as input "features" and the prediction column as
     * output "prediction". The probability column is added as output "probability" only when
     * {@code node.get(node.probabilityCol())} returns a value.
     *
     * <p>NOTE(review): Spark's {@code Params.get} is understood to return only explicitly set
     * values, not defaults. A model that relies on the default probability column would then be
     * written without a "probability" output. Confirm this is intended.
     *
     * @param node the Spark model
     * @return the shape describing the node's columns
     */
    public Shape shape(LogisticRegressionModel node) {
        Shape s = new Shape(Shape$.MODULE$.apply$default$1()).withInput(node.getFeaturesCol(), "features").withOutput(node.getPredictionCol(), "prediction");
        return (Shape)node.get(node.probabilityCol()).map((Function1)new Serializable(s){
            public static final long serialVersionUID = 0L;
            private final Shape s$1;

            public final Shape apply(String p) {
                return this.s$1.withOutput(p, "probability");
            }
            {
                this.s$1 = s$1;
            }
        }).getOrElse((Function0)new Serializable(s){
            public static final long serialVersionUID = 0L;
            private final Shape s$1;

            // Probability column not set: return the base shape.
            public final Shape apply() {
                return this.s$1;
            }
            {
                this.s$1 = s$1;
            }
        });
    }

    /**
     * Registers this instance as {@link #MODULE$} and builds the {@link OpModel} serializer.
     */
    private LogisticRegressionOp$() {
        MODULE$ = this;
        this.Model = new OpModel<LogisticRegressionModel>(){

            /** Returns the bundle op name for logistic regression from the built-in classification ops. */
            public String opName() {
                return Bundle.BuiltinOps$.classification$.MODULE$.logistic_regression();
            }

            /**
             * Writes the model's attributes to the bundle:
             * <ul>
             *   <li>"coefficients": a double vector built from {@code obj.coefficients().toArray()};</li>
             *   <li>"intercept": a double;</li>
             *   <li>"num_classes": a long from {@code obj.numClasses()};</li>
             *   <li>"threshold": a double, written only when {@code obj.get(obj.threshold())}
             *       returns a value.</li>
             * </ul>
             *
             * @param context the bundle context (not used by this method)
             * @param model   the writable model to add attributes to
             * @param obj     the Spark model being stored
             * @return the {@link WritableModel} returned by the last {@code withAttr} call
             */
            public WritableModel store(BundleContext context, WritableModel model, LogisticRegressionModel obj) {
                WritableModel m = model.withAttr(new Attribute("coefficients", Value$.MODULE$.doubleVector((Seq)Predef$.MODULE$.wrapDoubleArray(obj.coefficients().toArray())))).withAttr(new Attribute("intercept", Value$.MODULE$.double(obj.intercept()))).withAttr(new Attribute("num_classes", Value$.MODULE$.long((long)obj.numClasses())));
                // "threshold" is optional: it is only written when get(threshold) yields a value.
                return (WritableModel)obj.get((Param)obj.threshold()).map((Function1)new Serializable(this, m){
                    public static final long serialVersionUID = 0L;
                    private final WritableModel m$1;

                    public final WritableModel apply(double t) {
                        return this.m$1.withAttr(new Attribute("threshold", Value$.MODULE$.double(t)));
                    }
                    {
                        this.m$1 = m$1;
                    }
                }).getOrElse((Function0)new Serializable(this, m){
                    public static final long serialVersionUID = 0L;
                    private final WritableModel m$1;

                    public final WritableModel apply() {
                        return this.m$1;
                    }
                    {
                        this.m$1 = m$1;
                    }
                });
            }

            /**
             * Reads a binary logistic regression model from the bundle attributes.
             *
             * <p>The model is created with an empty uid. This op's node-level {@code load}
             * constructs a fresh model that uses the node's name as its uid. The "threshold"
             * attribute is optional and is applied via {@code setThreshold} when present.
             *
             * @param context the bundle context (not used by this method)
             * @param model   the readable model holding "num_classes", "coefficients",
             *                "intercept" and optionally "threshold"
             * @return the loaded Spark model
             * @throws Error if "num_classes" is not 2
             */
            public LogisticRegressionModel load(BundleContext context, ReadableModel model) {
                // NOTE(review): this throws java.lang.Error, not an Exception subtype, so callers
                // that catch Exception will not intercept it.
                if (model.value("num_classes").getLong() != 2L) {
                    throw new Error("Only binary logistic regression supported in Spark");
                }
                LogisticRegressionModel lr = new LogisticRegressionModel("", Vectors$.MODULE$.dense((double[])model.value("coefficients").getDoubleVector().toArray(ClassTag$.MODULE$.Double())), model.value("intercept").getDouble());
                return (LogisticRegressionModel)model.getValue("threshold").map((Function1)new Serializable(this, lr){
                    public static final long serialVersionUID = 0L;
                    private final LogisticRegressionModel lr$1;

                    public final LogisticRegressionModel apply(Value t) {
                        return this.lr$1.setThreshold(t.getDouble());
                    }
                    {
                        this.lr$1 = lr$1;
                    }
                }).getOrElse((Function0)new Serializable(this, lr){
                    public static final long serialVersionUID = 0L;
                    private final LogisticRegressionModel lr$1;

                    public final LogisticRegressionModel apply() {
                        return this.lr$1;
                    }
                    {
                        this.lr$1 = lr$1;
                    }
                });
            }
        };
    }
}

