/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.ml.bundle.ops.classification;

import ml.combust.bundle.BundleContext;
import ml.combust.bundle.dsl.Bundle;
import ml.combust.bundle.dsl.HasAttributes;
import ml.combust.bundle.dsl.Model;
import ml.combust.bundle.dsl.NodeShape;
import ml.combust.bundle.dsl.Value;
import ml.combust.bundle.dsl.Value$;
import ml.combust.bundle.op.OpModel;
import ml.combust.mleap.tensor.DenseTensor;
import ml.combust.mleap.tensor.Tensor;
import org.apache.spark.ml.bundle.ParamSpec;
import org.apache.spark.ml.bundle.ParamSpec$;
import org.apache.spark.ml.bundle.SimpleParamSpec;
import org.apache.spark.ml.bundle.SimpleSparkOp;
import org.apache.spark.ml.bundle.SparkBundleContext;
import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.ml.linalg.Matrices$;
import org.apache.spark.ml.linalg.Matrix;
import org.apache.spark.ml.linalg.Vectors$;
import org.apache.spark.ml.param.Param;
import scala.Function0;
import scala.Function1;
import scala.None$;
import scala.Predef;
import scala.Predef$;
import scala.Serializable;
import scala.Some;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

@ScalaSignature(bytes="\u0006\u0001\u001d4A!\u0001\u0002\u0001#\t9Bj\\4jgRL7MU3he\u0016\u001c8/[8o\u001fB4&'\r\u0006\u0003\u0007\u0011\tab\u00197bgNLg-[2bi&|gN\u0003\u0002\u0006\r\u0005\u0019q\u000e]:\u000b\u0005\u001dA\u0011A\u00022v]\u0012dWM\u0003\u0002\n\u0015\u0005\u0011Q\u000e\u001c\u0006\u0003\u00171\tQa\u001d9be.T!!\u0004\b\u0002\r\u0005\u0004\u0018m\u00195f\u0015\u0005y\u0011aA8sO\u000e\u00011C\u0001\u0001\u0013!\r\u0019BCF\u0007\u0002\r%\u0011QC\u0002\u0002\u000e'&l\u0007\u000f\\3Ta\u0006\u00148n\u00149\u0011\u0005]IR\"\u0001\r\u000b\u0005\rA\u0011B\u0001\u000e\u0019\u0005]aunZ5ti&\u001c'+Z4sKN\u001c\u0018n\u001c8N_\u0012,G\u000eC\u0003\u001d\u0001\u0011\u0005Q$\u0001\u0004=S:LGO\u0010\u000b\u0002=A\u0011q\u0004A\u0007\u0002\u0005!9\u0011\u0005\u0001b\u0001\n\u0003\u0012\u0013!B'pI\u0016dW#A\u0012\u0011\t\u0011ZSFF\u0007\u0002K)\u0011aeJ\u0001\u0003_BT!a\u0002\u0015\u000b\u0005%R\u0013aB2p[\n,8\u000f\u001e\u0006\u0002\u0013%\u0011A&\n\u0002\b\u001fBlu\u000eZ3m!\t\u0019b&\u0003\u00020\r\t\u00112\u000b]1sW\n+h\u000e\u001a7f\u0007>tG/\u001a=u\u0011\u0019\t\u0004\u0001)A\u0005G\u00051Qj\u001c3fY\u0002BQa\r\u0001\u0005BQ\n\u0011b\u001d9be.du.\u00193\u0015\tY)\u0014)\u0013\u0005\u0006mI\u0002\raN\u0001\u0004k&$\u0007C\u0001\u001d?\u001d\tID(D\u0001;\u0015\u0005Y\u0014!B:dC2\f\u0017BA\u001f;\u0003\u0019\u0001&/\u001a3fM&\u0011q\b\u0011\u0002\u0007'R\u0014\u0018N\\4\u000b\u0005uR\u0004\"\u0002\"3\u0001\u0004\u0019\u0015!B:iCB,\u0007C\u0001#H\u001b\u0005)%B\u0001$(\u0003\r!7\u000f\\\u0005\u0003\u0011\u0016\u0013\u0011BT8eKNC\u0017\r]3\t\u000b)\u0013\u0004\u0019\u0001\f\u0002\u000b5|G-\u001a7\t\u000b1\u0003A\u0011I'\u0002\u0017M\u0004\u0018M]6J]B,Ho\u001d\u000b\u0003\u001dv\u00032aT,[\u001d\t\u0001VK\u0004\u0002R)6\t!K\u0003\u0002T!\u00051AH]8pizJ\u0011aO\u0005\u0003-j\nq\u0001]1dW\u0006<W-\u0003\u0002Y3\n\u00191+Z9\u000b\u0005YS\u0004CA\n\\\u0013\tafAA\u0005QCJ\fWn\u00159fG\")al\u0013a\u0001-\u0005\u0019qN\u00196\t\u000b\u0001\u0004A\u0011I1\u0002\u0019M\u0004\u001
8M]6PkR\u0004X\u000f^:\u0015\u0005\t4\u0007cA(XGB\u00111\u0003Z\u0005\u0003K\u001a\u0011qbU5na2,\u0007+\u0019:b[N\u0003Xm\u0019\u0005\u0006=~\u0003\rA\u0006")
// Bundle serialization op for Spark's LogisticRegressionModel.
//
// Layout written by store() and read back by load():
//   - always:                 "num_classes" (long); selects the layout on load
//   - numClasses > 2:         "coefficient_matrix" (tensor [numRows, numCols]),
//                             "intercept_vector" (vector), optional "thresholds" (double list)
//   - numClasses <= 2:        "coefficients" (vector), "intercept" (double), "threshold" (double)
//
// NOTE(review): this class is CFR output decompiled from Scala. Constructs such as
// `Value$.MODULE$.long(...)` / `.double(...)` and the anonymous `Serializable(this, lr)`
// closures are decompiler artifacts and are not valid Java source as-is.
public class LogisticRegressionOpV21
extends SimpleSparkOp<LogisticRegressionModel> {
    /** Adapter that converts a LogisticRegressionModel to and from a bundle Model. */
    private final OpModel<SparkBundleContext, LogisticRegressionModel> Model = new OpModel<SparkBundleContext, LogisticRegressionModel>(this){
        private final Class<LogisticRegressionModel> klazz;

        /** @return the Spark model class handled by this op. */
        public Class<LogisticRegressionModel> klazz() {
            return this.klazz;
        }

        /** @return the builtin bundle op name for logistic regression. */
        public String opName() {
            return Bundle.BuiltinOps$.classification$.MODULE$.logistic_regression();
        }

        /**
         * Writes the parameters of {@code obj} into the bundle {@code model}.
         *
         * @param model   bundle model to add attributes to
         * @param obj     Spark model being serialized
         * @param context bundle context (unused here)
         * @return the bundle model with all attributes attached
         */
        public Model store(Model model, LogisticRegressionModel obj, BundleContext<SparkBundleContext> context) {
            Model model2;
            Model m = (Model)model.withValue("num_classes", Value$.MODULE$.long((long)obj.numClasses()));
            if (obj.numClasses() > 2) {
                Matrix cm = obj.coefficientMatrix();
                // Only serialize thresholds the user explicitly set; otherwise write nothing.
                scala.Option thresholds = obj.isSet((Param)obj.thresholds()) ? new Some((Object)obj.getThresholds()) : None$.MODULE$;
                // Coefficients are stored as a [numRows, numCols] tensor of cm.toArray();
                // load() rebuilds the matrix from the same dimensions.
                model2 = (Model)((HasAttributes)((HasAttributes)m.withValue("coefficient_matrix", Value$.MODULE$.tensor((Tensor)new DenseTensor((Object)cm.toArray(), (Seq)Seq$.MODULE$.apply((Seq)Predef$.MODULE$.wrapIntArray(new int[]{cm.numRows(), cm.numCols()})), ClassTag$.MODULE$.Double())))).withValue("intercept_vector", Value$.MODULE$.vector((Object)obj.interceptVector().toArray(), ClassTag$.MODULE$.Double()))).withValue("thresholds", thresholds.map((Function1)new Serializable(this){
                    public static final long serialVersionUID = 0L;

                    public final Seq<Object> apply(double[] x$1) {
                        return Predef$.MODULE$.doubleArrayOps(x$1).toSeq();
                    }
                }).map((Function1)new Serializable(this){
                    public static final long serialVersionUID = 0L;

                    public final Value apply(Seq<Object> value) {
                        return Value$.MODULE$.doubleList(value);
                    }
                }));
            } else {
                // Binary model: scalar intercept and a single decision threshold.
                model2 = (Model)((HasAttributes)((HasAttributes)m.withValue("coefficients", Value$.MODULE$.vector((Object)obj.coefficients().toArray(), ClassTag$.MODULE$.Double()))).withValue("intercept", Value$.MODULE$.double(obj.intercept()))).withValue("threshold", Value$.MODULE$.double(obj.getThreshold()));
            }
            return model2;
        }

        /**
         * Rebuilds a LogisticRegressionModel from the attributes written by {@link #store}.
         * The uid is left empty here; {@code sparkLoad} later copies the model under the real uid.
         *
         * @param model   bundle model to read attributes from
         * @param context bundle context (unused here)
         * @return the reconstructed Spark model
         */
        public LogisticRegressionModel load(Model model, BundleContext<SparkBundleContext> context) {
            LogisticRegressionModel logisticRegressionModel;
            long numClasses = model.value("num_classes").getLong();
            if (numClasses > 2L) {
                Tensor cmTensor = model.value("coefficient_matrix").getTensor();
                // Dimensions are [numRows, numCols] as written by store(); assumes the tensor
                // data keeps the element order produced by Matrix.toArray() there.
                Matrix coefficientMatrix = Matrices$.MODULE$.dense(BoxesRunTime.unboxToInt((Object)cmTensor.dimensions().head()), BoxesRunTime.unboxToInt((Object)cmTensor.dimensions().apply(1)), (double[])cmTensor.toArray());
                // Multinomial intercepts live under "intercept_vector", the key store() writes.
                // Reading "intercept" (written only for binary models) looked up an attribute
                // that a multinomial bundle never contains.
                LogisticRegressionModel lr = new LogisticRegressionModel("", coefficientMatrix, Vectors$.MODULE$.dense((double[])model.value("intercept_vector").getTensor().toArray()), (int)numClasses, true);
                // "thresholds" is optional: apply it when present, otherwise keep lr unchanged.
                logisticRegressionModel = (LogisticRegressionModel)model.getValue("thresholds").map((Function1)new Serializable(this, lr){
                    public static final long serialVersionUID = 0L;
                    private final LogisticRegressionModel lr$1;

                    public final LogisticRegressionModel apply(Value t) {
                        return this.lr$1.setThresholds((double[])t.getDoubleList().toArray(ClassTag$.MODULE$.Double()));
                    }
                    {
                        this.lr$1 = lr$1;
                    }
                }).getOrElse((Function0)new Serializable(this, lr){
                    public static final long serialVersionUID = 0L;
                    private final LogisticRegressionModel lr$1;

                    public final LogisticRegressionModel apply() {
                        return this.lr$1;
                    }
                    {
                        this.lr$1 = lr$1;
                    }
                });
            } else {
                LogisticRegressionModel lr = new LogisticRegressionModel("", Vectors$.MODULE$.dense((double[])model.value("coefficients").getTensor().toArray()), model.value("intercept").getDouble());
                // "threshold" is optional on load: apply it when present, otherwise keep lr unchanged.
                logisticRegressionModel = (LogisticRegressionModel)model.getValue("threshold").map((Function1)new Serializable(this, lr){
                    public static final long serialVersionUID = 0L;
                    private final LogisticRegressionModel lr$2;

                    public final LogisticRegressionModel apply(Value t) {
                        return this.lr$2.setThreshold(t.getDouble());
                    }
                    {
                        this.lr$2 = lr$2;
                    }
                }).getOrElse((Function0)new Serializable(this, lr){
                    public static final long serialVersionUID = 0L;
                    private final LogisticRegressionModel lr$2;

                    public final LogisticRegressionModel apply() {
                        return this.lr$2;
                    }
                    {
                        this.lr$2 = lr$2;
                    }
                });
            }
            LogisticRegressionModel r = logisticRegressionModel;
            return r;
        }
        {
            this.klazz = LogisticRegressionModel.class;
        }
    };

    /** @return the bundle model adapter for this op. */
    public OpModel<SparkBundleContext, LogisticRegressionModel> Model() {
        return this.Model;
    }

    /**
     * Copies {@code model} into a new LogisticRegressionModel carrying the given {@code uid},
     * preserving coefficients, intercepts, class count and decision threshold(s).
     *
     * @param uid   uid for the new model
     * @param shape node shape (unused here)
     * @param model model produced by {@code load}, whose uid is empty
     * @return a copy of {@code model} with {@code uid}
     */
    public LogisticRegressionModel sparkLoad(String uid, NodeShape shape, LogisticRegressionModel model) {
        int numClasses = model.numClasses();
        if (numClasses > 2) {
            LogisticRegressionModel lr = new LogisticRegressionModel(uid, model.coefficientMatrix(), model.interceptVector(), numClasses, true);
            // Test the *source* model, mirroring store(). The freshly built `lr` has no
            // thresholds set, so testing it meant explicit thresholds were never copied over.
            if (model.isSet((Param)model.thresholds())) {
                lr.setThresholds(model.getThresholds());
            }
            return lr;
        }
        LogisticRegressionModel lr = new LogisticRegressionModel(uid, model.coefficientMatrix(), model.interceptVector(), numClasses, false);
        if (model.isDefined((Param)model.threshold())) {
            lr.setThreshold(model.getThreshold());
        }
        return lr;
    }

    /**
     * @param obj Spark model
     * @return input port mapping: bundle port "features" to the model's features column
     */
    public Seq<ParamSpec> sparkInputs(LogisticRegressionModel obj) {
        return (Seq)Seq$.MODULE$.apply((Seq)Predef$.MODULE$.wrapRefArray((Object[])new SimpleParamSpec[]{ParamSpec$.MODULE$.apply(Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"features"), (Object)obj.featuresCol()))}));
    }

    /**
     * @param obj Spark model
     * @return output port mappings for "raw_prediction", "probability" and "prediction"
     */
    public Seq<SimpleParamSpec> sparkOutputs(LogisticRegressionModel obj) {
        return (Seq)Seq$.MODULE$.apply((Seq)Predef$.MODULE$.wrapRefArray((Object[])new SimpleParamSpec[]{ParamSpec$.MODULE$.apply(Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"raw_prediction"), (Object)obj.rawPredictionCol())), ParamSpec$.MODULE$.apply(Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"probability"), (Object)obj.probabilityCol())), ParamSpec$.MODULE$.apply(Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"prediction"), (Object)obj.predictionCol()))}));
    }

    /** Registers this op for LogisticRegressionModel via its ClassTag. */
    public LogisticRegressionOpV21() {
        super(ClassTag$.MODULE$.apply(LogisticRegressionModel.class));
    }
}

