/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.ml.bundle.ops.classification;

import java.io.Serializable;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.SerializedLambda;
import ml.combust.bundle.BundleContext;
import ml.combust.bundle.dsl.Bundle;
import ml.combust.bundle.dsl.HasAttributes;
import ml.combust.bundle.dsl.Model;
import ml.combust.bundle.dsl.NodeShape;
import ml.combust.bundle.dsl.Value;
import ml.combust.bundle.dsl.Value$;
import ml.combust.bundle.op.OpModel;
import ml.combust.mleap.tensor.DenseTensor;
import ml.combust.mleap.tensor.Tensor;
import org.apache.spark.ml.bundle.ParamSpec;
import org.apache.spark.ml.bundle.ParamSpec$;
import org.apache.spark.ml.bundle.SimpleParamSpec;
import org.apache.spark.ml.bundle.SimpleSparkOp;
import org.apache.spark.ml.bundle.SparkBundleContext;
import org.apache.spark.ml.bundle.ops.classification.LogisticRegressionOp$;
import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.ml.linalg.Matrices$;
import org.apache.spark.ml.linalg.Matrix;
import org.apache.spark.ml.linalg.Vectors$;
import org.apache.spark.ml.param.Param;
import scala.Function0;
import scala.Function1;
import scala.None$;
import scala.Predef;
import scala.Predef$;
import scala.Some;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.mutable.ArrayOps;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.LambdaDeserialize;
import scala.runtime.java8.JFunction0;

@ScalaSignature(bytes="\u0006\u000194A!\u0003\u0006\u00013!)1\u0005\u0001C\u0001I!9q\u0005\u0001b\u0001\n\u001bA\u0003BB\u0016\u0001A\u00035\u0011\u0006C\u0004-\u0001\t\u0007I\u0011I\u0017\t\rm\u0002\u0001\u0015!\u0003/\u0011\u0015a\u0004\u0001\"\u0011>\u0011\u00159\u0006\u0001\"\u0011Y\u0011\u00159\u0007\u0001\"\u0011i\u0005QaunZ5ti&\u001c'+Z4sKN\u001c\u0018n\u001c8Pa*\u00111\u0002D\u0001\u000fG2\f7o]5gS\u000e\fG/[8o\u0015\tia\"A\u0002paNT!a\u0004\t\u0002\r\t,h\u000e\u001a7f\u0015\t\t\"#\u0001\u0002nY*\u00111\u0003F\u0001\u0006gB\f'o\u001b\u0006\u0003+Y\ta!\u00199bG\",'\"A\f\u0002\u0007=\u0014xm\u0001\u0001\u0014\u0005\u0001Q\u0002cA\u000e\u001d=5\ta\"\u0003\u0002\u001e\u001d\ti1+[7qY\u0016\u001c\u0006/\u0019:l\u001fB\u0004\"aH\u0011\u000e\u0003\u0001R!a\u0003\t\n\u0005\t\u0002#a\u0006'pO&\u001cH/[2SK\u001e\u0014Xm]:j_:lu\u000eZ3m\u0003\u0019a\u0014N\\5u}Q\tQ\u0005\u0005\u0002'\u00015\t!\"A\u0013M\u001f\u001eK5\u000bV%D?J+uIU#T'&{ej\u0018#F\r\u0006+F\nV0U\u0011J+5\u000bS(M\tV\t\u0011fD\u0001+A!y\u0004\u001d\u0001\u0001\u0001\u0001\u0001\u0001\u0011A\n'P\u000f&\u001bF+S\"`%\u0016;%+R*T\u0013>su\fR#G\u0003VcEk\u0018+I%\u0016\u001b\u0006j\u0014'EA\u0005)Qj\u001c3fYV\ta\u0006\u0005\u00030marR\"\u0001\u0019\u000b\u0005E\u0012\u0014AA8q\u0015\ty1G\u0003\u00025k\u000591m\\7ckN$(\"A\t\n\u0005]\u0002$aB(q\u001b>$W\r\u001c\t\u00037eJ!A\u000f\b\u0003%M\u0003\u0018M]6Ck:$G.Z\"p]R,\u0007\u0010^\u0001\u0007\u001b>$W\r\u001c\u0011\u0002\u0013M\u0004\u0018M]6M_\u0006$G\u0003\u0002\u0010?\u001bVCQa\u0010\u0004A\u0002\u0001\u000b1!^5e!\t\t%J\u0004\u0002C\u0011B\u00111IR\u0007\u0002\t*\u0011Q\tG\u0001\u0007yI|w\u000e\u001e 
\u000b\u0003\u001d\u000bQa]2bY\u0006L!!\u0013$\u0002\rA\u0013X\rZ3g\u0013\tYEJ\u0001\u0004TiJLgn\u001a\u0006\u0003\u0013\u001aCQA\u0014\u0004A\u0002=\u000bQa\u001d5ba\u0016\u0004\"\u0001U*\u000e\u0003ES!A\u0015\u001a\u0002\u0007\u0011\u001cH.\u0003\u0002U#\nIaj\u001c3f'\"\f\u0007/\u001a\u0005\u0006-\u001a\u0001\rAH\u0001\u0006[>$W\r\\\u0001\fgB\f'o[%oaV$8\u000f\u0006\u0002ZKB\u0019!l\u00182\u000f\u0005mkfBA\"]\u0013\u00059\u0015B\u00010G\u0003\u001d\u0001\u0018mY6bO\u0016L!\u0001Y1\u0003\u0007M+\u0017O\u0003\u0002_\rB\u00111dY\u0005\u0003I:\u0011\u0011\u0002U1sC6\u001c\u0006/Z2\t\u000b\u0019<\u0001\u0019\u0001\u0010\u0002\u0007=\u0014'.\u0001\u0007ta\u0006\u00148nT;uaV$8\u000f\u0006\u0002j[B\u0019!l\u00186\u0011\u0005mY\u0017B\u00017\u000f\u0005=\u0019\u0016.\u001c9mKB\u000b'/Y7Ta\u0016\u001c\u0007\"\u00024\t\u0001\u0004q\u0002")
// MLeap bundle op that serializes a Spark ML LogisticRegressionModel to and from
// bundle Model attributes, and declares its Spark input/output column mapping.
//
// NOTE(review): this class is decompiler output generated from Scala bytecode (see the
// file header). Members such as $anonfun$load$3, $deserializeLambda$ and the
// OpModel.$init$ call are compiler-synthesized, and some constructs (for example the
// MethodHandle array in $deserializeLambda$) are not valid Java source. Edit the Scala
// original rather than this file.
public class LogisticRegressionOp
extends SimpleSparkOp<LogisticRegressionModel> {
    // Bundle (de)serializer for LogisticRegressionModel; exposed through Model().
    // NOTE(review): OpModel appears to be a Scala trait (static modelOpName$ / $init$
    // helpers are used below), so the "(null)" constructor argument is a decompiler artifact.
    private final OpModel<SparkBundleContext, LogisticRegressionModel> Model = new OpModel<SparkBundleContext, LogisticRegressionModel>(null){
        // Runtime class of the model handled by this op; assigned in the instance initializer.
        private final Class<LogisticRegressionModel> klazz;

        /** Delegates to the OpModel trait's default modelOpName implementation. */
        public String modelOpName(Object obj, BundleContext context) {
            return OpModel.modelOpName$((OpModel)this, (Object)obj, (BundleContext)context);
        }

        /** Returns the model class this op serializes ({@code LogisticRegressionModel}). */
        public Class<LogisticRegressionModel> klazz() {
            return this.klazz;
        }

        /** Returns the built-in bundle op name for logistic regression. */
        public String opName() {
            return Bundle.BuiltinOps$.classification$.MODULE$.logistic_regression();
        }

        /**
         * Writes the fitted parameters of {@code obj} into the bundle model.
         *
         * <p>Always writes {@code num_classes}. When there are more than 2 classes it writes
         * {@code coefficient_matrix} (a tensor with dimensions [numRows, numCols]),
         * {@code intercept_vector}, and {@code thresholds} (only when the thresholds param is
         * set on {@code obj}). Otherwise it writes {@code coefficients}, {@code intercept} and
         * {@code threshold}.
         *
         * @param model   the bundle model to add attributes to
         * @param obj     the Spark model being serialized
         * @param context the bundle context (not read here)
         * @return the bundle model with the attributes added
         */
        public Model store(Model model, LogisticRegressionModel obj, BundleContext<SparkBundleContext> context) {
            Model model2;
            Model m = (Model)model.withValue("num_classes", Value$.MODULE$.long((long)obj.numClasses()));
            if (obj.numClasses() > 2) {
                // Multinomial case: store the full coefficient matrix plus an intercept per class.
                Matrix cm = obj.coefficientMatrix();
                // Optional thresholds: Some(array) only if explicitly set. The declared type None$
                // is a decompiler artifact; the value is a Scala Option.
                None$ thresholds = obj.isSet((Param)obj.thresholds()) ? new Some((Object)obj.getThresholds()) : None$.MODULE$;
                // NOTE(review): cm.toArray() ordering must match Matrices.dense in load()
                // (both column-major in Spark) - confirm if either side changes.
                // The thresholds Option is passed straight to withValue; presumably the
                // attribute is omitted when it is None - verify against the bundle DSL.
                model2 = (Model)((HasAttributes)((HasAttributes)m.withValue("coefficient_matrix", Value$.MODULE$.tensor((Tensor)new DenseTensor((Object)cm.toArray(), (Seq)Seq$.MODULE$.apply((Seq)Predef$.MODULE$.wrapIntArray(new int[]{cm.numRows(), cm.numCols()})), ClassTag$.MODULE$.Double())))).withValue("intercept_vector", Value$.MODULE$.vector((Object)obj.interceptVector().toArray(), ClassTag$.MODULE$.Double()))).withValue("thresholds", thresholds.map((Function1 & Serializable & scala.Serializable)x$1 -> new ArrayOps.ofDouble(Predef$.MODULE$.doubleArrayOps(x$1)).toSeq()).map((Function1 & Serializable & scala.Serializable)value -> Value$.MODULE$.doubleList(value)));
            } else {
                // Binary case: coefficient vector, scalar intercept, and the threshold (always written).
                model2 = (Model)((HasAttributes)((HasAttributes)m.withValue("coefficients", Value$.MODULE$.vector((Object)obj.coefficients().toArray(), ClassTag$.MODULE$.Double()))).withValue("intercept", Value$.MODULE$.double(obj.intercept()))).withValue("threshold", Value$.MODULE$.double(obj.getThreshold()));
            }
            return model2;
        }

        /**
         * Rebuilds a LogisticRegressionModel from the attributes written by {@link #store}.
         *
         * <p>For more than 2 classes, the coefficient matrix is rebuilt from the tensor's first
         * two dimensions. The model is created as multinomial, and the thresholds are applied
         * only when present. For the binary case, {@code threshold} falls back to 0.5 when
         * absent. The model is created with an empty uid; {@code sparkLoad} later copies it
         * under the real uid.
         *
         * @param model   the bundle model to read attributes from
         * @param context the bundle context (not read here)
         * @return the reconstructed Spark model
         */
        public LogisticRegressionModel load(Model model, BundleContext<SparkBundleContext> context) {
            LogisticRegressionModel logisticRegressionModel;
            long numClasses = model.value("num_classes").getLong();
            if (numClasses > 2L) {
                Tensor cmTensor = model.value("coefficient_matrix").getTensor();
                // dimensions().head() = rows, dimensions().apply(1) = columns (as written by store()).
                Matrix coefficientMatrix = Matrices$.MODULE$.dense(BoxesRunTime.unboxToInt((Object)cmTensor.dimensions().head()), BoxesRunTime.unboxToInt((Object)cmTensor.dimensions().apply(1)), (double[])cmTensor.toArray());
                LogisticRegressionModel lr = new LogisticRegressionModel("", coefficientMatrix, Vectors$.MODULE$.dense((double[])model.value("intercept_vector").getTensor().toArray()), (int)numClasses, true);
                // Apply stored thresholds if the attribute exists; otherwise return lr unchanged.
                logisticRegressionModel = (LogisticRegressionModel)model.getValue("thresholds").map((Function1 & Serializable & scala.Serializable)t -> lr.setThresholds((double[])t.getDoubleList().toArray(ClassTag$.MODULE$.Double()))).getOrElse((Function0 & Serializable & scala.Serializable)() -> lr);
            } else {
                LogisticRegressionModel lr = new LogisticRegressionModel("", Vectors$.MODULE$.dense((double[])model.value("coefficients").getTensor().toArray()), model.value("intercept").getDouble());
                // The 0.5 fallback has the same value as LOGISTIC_REGRESSION_DEFAULT_THRESHOLD()
                // (the constant appears to have been inlined by the Scala compiler).
                double threshold = BoxesRunTime.unboxToDouble((Object)model.getValue("threshold").map((Function1 & Serializable & scala.Serializable)value -> BoxesRunTime.boxToDouble((double)$anon$1.$anonfun$load$3(value))).getOrElse((Function0)(JFunction0.mcD.sp & Serializable & scala.Serializable)() -> 0.5));
                logisticRegressionModel = lr.setThreshold(threshold);
            }
            LogisticRegressionModel r = logisticRegressionModel;
            return r;
        }

        /** Compiler-synthesized lambda body: extracts the double payload of a bundle Value. */
        public static final /* synthetic */ double $anonfun$load$3(Value value) {
            return value.getDouble();
        }
        // Instance initializer: runs the OpModel trait initializer, then records the model class.
        {
            OpModel.$init$((OpModel)this);
            this.klazz = LogisticRegressionModel.class;
        }

        // Compiler-synthesized hook for deserializing the serializable lambdas above.
        // NOTE(review): the MethodHandle array is a decompiler placeholder, not valid Java.
        private static /* synthetic */ Object $deserializeLambda$(SerializedLambda serializedLambda) {
            return LambdaDeserialize.bootstrap("lambdaDeserialize", new MethodHandle[]{$anonfun$store$1(double[] ), $anonfun$store$2(scala.collection.Seq ), $anonfun$load$1(org.apache.spark.ml.classification.LogisticRegressionModel ml.combust.bundle.dsl.Value ), $anonfun$load$2(org.apache.spark.ml.classification.LogisticRegressionModel ), $anonfun$load$3$adapted(ml.combust.bundle.dsl.Value ), $anonfun$load$4()}, serializedLambda);
        }
    };

    /** Default binary decision threshold (0.5); the same value is used as the fallback in load(). */
    private final double LOGISTIC_REGRESSION_DEFAULT_THRESHOLD() {
        return 0.5;
    }

    /** Returns the bundle (de)serializer for LogisticRegressionModel. */
    public OpModel<SparkBundleContext, LogisticRegressionModel> Model() {
        return this.Model;
    }

    /**
     * Copies a loaded model into a new LogisticRegressionModel carrying the given uid.
     *
     * <p>The coefficient matrix and intercept vector are reused. The model is flagged
     * multinomial when {@code numClasses > 2}, and thresholds (multinomial) or threshold
     * (binary) are copied when defined on {@code model}.
     *
     * @param uid   the uid to assign to the rebuilt model
     * @param shape the node shape (not read here)
     * @param model the model produced by {@code load}, typically with an empty uid
     * @return a new model with {@code uid}
     */
    public LogisticRegressionModel sparkLoad(String uid, NodeShape shape, LogisticRegressionModel model) {
        LogisticRegressionModel logisticRegressionModel;
        int numClasses = model.numClasses();
        if (numClasses > 2) {
            LogisticRegressionModel lr = new LogisticRegressionModel(uid, model.coefficientMatrix(), model.interceptVector(), numClasses, true);
            // Result discarded: setThresholds is called only for its effect on lr.
            // NOTE(review): uses isDefined here while store() uses isSet - confirm the difference is intended.
            Object object = model.isDefined((Param)model.thresholds()) ? lr.setThresholds(model.getThresholds()) : BoxedUnit.UNIT;
            logisticRegressionModel = lr;
        } else {
            LogisticRegressionModel lr = new LogisticRegressionModel(uid, model.coefficientMatrix(), model.interceptVector(), numClasses, false);
            // Result discarded: setThreshold is called only for its effect on lr.
            Object object = model.isDefined((Param)model.threshold()) ? lr.setThreshold(model.getThreshold()) : BoxedUnit.UNIT;
            logisticRegressionModel = lr;
        }
        LogisticRegressionModel r = logisticRegressionModel;
        return r;
    }

    /** Maps the bundle input port "features" to the model's features column. */
    public Seq<ParamSpec> sparkInputs(LogisticRegressionModel obj) {
        return (Seq)Seq$.MODULE$.apply((Seq)Predef$.MODULE$.wrapRefArray((Object[])new SimpleParamSpec[]{ParamSpec$.MODULE$.apply(Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"features"), (Object)obj.featuresCol()))}));
    }

    /** Maps the bundle output ports "raw_prediction", "probability" and "prediction" to the model's columns. */
    public Seq<SimpleParamSpec> sparkOutputs(LogisticRegressionModel obj) {
        return (Seq)Seq$.MODULE$.apply((Seq)Predef$.MODULE$.wrapRefArray((Object[])new SimpleParamSpec[]{ParamSpec$.MODULE$.apply(Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"raw_prediction"), (Object)obj.rawPredictionCol())), ParamSpec$.MODULE$.apply(Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"probability"), (Object)obj.probabilityCol())), ParamSpec$.MODULE$.apply(Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"prediction"), (Object)obj.predictionCol()))}));
    }

    /** Creates the op, passing the ClassTag for LogisticRegressionModel to SimpleSparkOp. */
    public LogisticRegressionOp() {
        super(ClassTag$.MODULE$.apply(LogisticRegressionModel.class));
    }
}

