/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.ml.odkl;

import odkl.analysis.spark.util.Logging;
import odkl.analysis.spark.util.Logging$class;
import org.apache.spark.ml.linalg.SparseVector;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.Vectors$;
import org.apache.spark.ml.odkl.FoldedFeaturesStatsAggregator;
import org.apache.spark.ml.odkl.HasWeights;
import org.apache.spark.ml.odkl.LinearModel;
import org.apache.spark.ml.odkl.LinearModelSignificantFeatureSelector;
import org.apache.spark.ml.odkl.ModelSummary;
import org.apache.spark.ml.odkl.ModelWithSummary;
import org.apache.spark.ml.odkl.SignificantFeatureSelector$;
import org.apache.spark.ml.odkl.SummarizableEstimator;
import org.apache.spark.ml.odkl.UnwrappedStage$;
import org.apache.spark.ml.odkl.WeightsStatRecord;
import org.apache.spark.ml.param.ParamMap$;
import org.apache.spark.ml.param.ParamPair;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.UserDefinedFunction;
import org.apache.spark.sql.functions$;
import org.apache.spark.sql.odkl.SparkSqlUtils$;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.slf4j.Logger;
import scala.Array$;
import scala.Function0;
import scala.Function1;
import scala.None$;
import scala.Option;
import scala.Predef;
import scala.Predef$;
import scala.Serializable;
import scala.Some;
import scala.StringContext;
import scala.Tuple2;
import scala.collection.Seq;
import scala.math.Ordering;
import scala.reflect.ClassTag$;
import scala.reflect.api.TypeTags;
import scala.reflect.runtime.package$;
import scala.runtime.TraitSetter;

/**
 * Singleton holder (decompiled Scala {@code object}) with helpers for significance-based
 * feature selection on linear models.
 *
 * <p>The single instance is created by the static initializer and published through
 * {@link #MODULE$}; {@link #readResolve()} returns that same instance on deserialization.
 * The class mixes in the {@code Logging} trait, so every {@code log*} method below simply
 * forwards to the static trait implementation in {@code Logging$class}.
 *
 * <p>NOTE(review): this file is CFR decompiler output, not hand-written Java. Several
 * constructs (anonymous classes with constructor arguments, {@code anonfun.17},
 * assignment of a {@code static final} field from the constructor) mirror bytecode and
 * are not expected to compile as-is.
 */
public final class SignificantFeatureSelector$
implements Serializable,
Logging {
    /** The one and only instance; assigned inside the private constructor. */
    public static final SignificantFeatureSelector$ MODULE$;
    /** Metadata key under which per-feature weight statistics are looked up ("features_stat"). */
    private final String WEIGHTS_STAT;
    /** Backing field for the Logging trait's logger; transient, so it is not serialized. */
    private transient Logger odkl$analysis$spark$util$Logging$$log_;

    // Instantiating the class is enough: the constructor stores `this` into MODULE$.
    static {
        new SignificantFeatureSelector$();
    }

    /** Trait-generated getter for the logger backing field. */
    @Override
    public Logger odkl$analysis$spark$util$Logging$$log_() {
        return this.odkl$analysis$spark$util$Logging$$log_;
    }

    /** Trait-generated setter for the logger backing field. */
    @Override
    @TraitSetter
    public void odkl$analysis$spark$util$Logging$$log__$eq(Logger x$1) {
        this.odkl$analysis$spark$util$Logging$$log_ = x$1;
    }

    /** Forwards to the Logging trait implementation of {@code logName}. */
    @Override
    public String logName() {
        return Logging$class.logName(this);
    }

    /** Forwards to the Logging trait implementation of {@code log}. */
    @Override
    public Logger log() {
        return Logging$class.log(this);
    }

    /** Forwards an info-level message (supplied as a thunk) to the Logging trait. */
    @Override
    public void logInfo(Function0<String> msg) {
        Logging$class.logInfo(this, msg);
    }

    /** Forwards a debug-level message (supplied as a thunk) to the Logging trait. */
    @Override
    public void logDebug(Function0<String> msg) {
        Logging$class.logDebug(this, msg);
    }

    /** Forwards a trace-level message (supplied as a thunk) to the Logging trait. */
    @Override
    public void logTrace(Function0<String> msg) {
        Logging$class.logTrace(this, msg);
    }

    /** Forwards a warning-level message (supplied as a thunk) to the Logging trait. */
    @Override
    public void logWarning(Function0<String> msg) {
        Logging$class.logWarning(this, msg);
    }

    /** Forwards an error-level message (supplied as a thunk) to the Logging trait. */
    @Override
    public void logError(Function0<String> msg) {
        Logging$class.logError(this, msg);
    }

    /** Forwards an info-level message and its associated throwable to the Logging trait. */
    @Override
    public void logInfo(Function0<String> msg, Throwable throwable) {
        Logging$class.logInfo(this, msg, throwable);
    }

    /** Forwards a debug-level message and its associated throwable to the Logging trait. */
    @Override
    public void logDebug(Function0<String> msg, Throwable throwable) {
        Logging$class.logDebug(this, msg, throwable);
    }

    /** Forwards a trace-level message and its associated throwable to the Logging trait. */
    @Override
    public void logTrace(Function0<String> msg, Throwable throwable) {
        Logging$class.logTrace(this, msg, throwable);
    }

    /** Forwards a warning-level message and its associated throwable to the Logging trait. */
    @Override
    public void logWarning(Function0<String> msg, Throwable throwable) {
        Logging$class.logWarning(this, msg, throwable);
    }

    /** Forwards an error-level message and its associated throwable to the Logging trait. */
    @Override
    public void logError(Function0<String> msg, Throwable throwable) {
        Logging$class.logError(this, msg, throwable);
    }

    /** Forwards to the Logging trait implementation of {@code isTraceEnabled}. */
    @Override
    public boolean isTraceEnabled() {
        return Logging$class.isTraceEnabled(this);
    }

    /**
     * Returns the metadata key used to locate weight statistics on a column.
     *
     * @return the constant {@code "features_stat"} set in the constructor
     */
    public String WEIGHTS_STAT() {
        return this.WEIGHTS_STAT;
    }

    /**
     * Tries to build a dense vector of initial weights from the statistics stored in a
     * column's metadata.
     *
     * <p>When the field has non-null metadata containing the {@link #WEIGHTS_STAT()} key,
     * the metadata array under that key is wrapped into {@code WeightsStatRecord}s; the
     * records reporting {@code isRelevant()} are kept, ordered by {@code index()}, and
     * their {@code average()} values become the entries of the resulting dense vector.
     * The vector is logged at info level together with the field name.
     *
     * @param field schema field whose metadata is inspected
     * @return {@code Some(vector)} when the statistics key is present, {@code None} otherwise
     */
    public Option<Vector> tryGetInitials(StructField field) {
        None$ none$;
        if (field.metadata() != null && field.metadata().contains(this.WEIGHTS_STAT())) {
            // Step 1: wrap every metadata entry stored under the key into a WeightsStatRecord.
            WeightsStatRecord[] stat = (WeightsStatRecord[])Predef$.MODULE$.refArrayOps((Object[])field.metadata().getMetadataArray(this.WEIGHTS_STAT())).map((Function1)new Serializable(){
                public static final long serialVersionUID = 0L;

                public final WeightsStatRecord apply(Metadata x) {
                    return new WeightsStatRecord(x);
                }
            }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(WeightsStatRecord.class)));
            // Step 2: filter(isRelevant) -> sortBy(index) -> map(average) -> dense vector.
            Vector dense = Vectors$.MODULE$.dense((double[])Predef$.MODULE$.refArrayOps((Object[])Predef$.MODULE$.refArrayOps((Object[])Predef$.MODULE$.refArrayOps((Object[])stat).filter((Function1)new Serializable(){
                public static final long serialVersionUID = 0L;

                public final boolean apply(WeightsStatRecord x$6) {
                    return x$6.isRelevant();
                }
            })).sortBy((Function1)new Serializable(){
                public static final long serialVersionUID = 0L;

                public final int apply(WeightsStatRecord x$7) {
                    return x$7.index();
                }
            }, (Ordering)Ordering.Int$.MODULE$)).map((Function1)new Serializable(){
                public static final long serialVersionUID = 0L;

                public final double apply(WeightsStatRecord x$8) {
                    return x$8.average();
                }
            }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.Double())));
            // The message is passed as a thunk capturing the field and the vector.
            this.logInfo((Function0<String>)new Serializable(field, dense){
                public static final long serialVersionUID = 0L;
                private final StructField field$1;
                private final Vector dense$1;

                public final String apply() {
                    return new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"Got initial weights for field ", ": ", ""})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{this.field$1.name(), this.dense$1}));
                }
                {
                    this.field$1 = field$1;
                    this.dense$1 = dense$1;
                }
            });
            none$ = new Some((Object)dense);
        } else {
            // No metadata, or no statistics stored under the key.
            none$ = None$.MODULE$;
        }
        return none$;
    }

    /*
     * Enabled aggressive block sorting
     * Enabled unnecessary exception pruning
     * Enabled aggressive exception aggregation
     */
    /**
     * Re-expresses a linear model trained on a reduced feature set in terms of the
     * original feature space.
     *
     * <p>The model's coefficient values are placed into a sparse vector of size
     * {@code originalSize} at the positions listed in {@code relevant}. The summary block
     * keyed by {@code model.weights()} is transformed so that its {@code model.index()}
     * column is remapped: a non-negative value {@code i} becomes {@code relevant[i]},
     * while negative values are kept as they are. The summary transformation runs while
     * holding {@code SparkSqlUtils.reflectionLock()}.
     *
     * <p>NOTE(review): presumably {@code relevant} has exactly one entry per model
     * coefficient and every non-negative index in the summary is below
     * {@code relevant.length}; neither is checked here - confirm against callers.
     *
     * @param originalSize size of the resulting coefficient vector
     * @param relevant     positions in the original space, indexed by reduced-space position
     * @param model        model whose coefficients and summary are remapped
     * @return a copy of {@code model} carrying the transformed summary and the sparse coefficients
     */
    public <M extends LinearModel<M>> M transformLinearModel(int originalSize, int[] relevant, M model) {
        ModelSummary nestedSummary = model.summary();
        // Reduced-space coefficient values placed at their original-space indices.
        SparseVector coefficients = new SparseVector(originalSize, relevant, model.getCoefficients().toArray());
        // NOTE(review): the lock name suggests it guards Scala runtime reflection used when
        // building the UDF's type tags - confirm in SparkSqlUtils.
        Object object = SparkSqlUtils$.MODULE$.reflectionLock();
        synchronized (object) {
            ModelSummary modelSummary = nestedSummary.transform((Tuple2<ModelWithSummary.Block, Function1<Dataset<Row>, Dataset<Row>>>)Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)model.weights()), (Object)new Serializable(relevant, model){
                public static final long serialVersionUID = 0L;
                public final int[] relevant$3;
                private final LinearModel model$1;

                // Replaces the index column of the block with its remapped value.
                public final Dataset<Row> apply(Dataset<Row> data) {
                    UserDefinedFunction reindex = functions$.MODULE$.udf((Function1)new Serializable(this){
                        public static final long serialVersionUID = 0L;
                        private final /* synthetic */ anonfun.17 $outer;

                        public final int apply(int i) {
                            return this.apply$mcII$sp(i);
                        }

                        // Non-negative indices are looked up in `relevant`; negative ones pass through.
                        public int apply$mcII$sp(int i) {
                            return i >= 0 ? this.$outer.relevant$3[i] : i;
                        }
                        {
                            if ($outer == null) {
                                throw null;
                            }
                            this.$outer = $outer;
                        }
                    }, ((TypeTags)package$.MODULE$.universe()).TypeTag().Int(), ((TypeTags)package$.MODULE$.universe()).TypeTag().Int());
                    return data.withColumn(this.model$1.index(), reindex.apply((Seq)Predef$.MODULE$.wrapRefArray((Object[])new Column[]{data.apply(this.model$1.index())})));
                }
                {
                    this.relevant$3 = relevant$3;
                    this.model$1 = model$1;
                }
            }), (Seq<Tuple2<ModelWithSummary.Block, Function1<Dataset<Row>, Dataset<Row>>>>)Predef$.MODULE$.wrapRefArray((Object[])new Tuple2[0]));
            // MONITOREXIT @DISABLED, blocks:[0, 1] lbl7 : MonitorExitStatement: MONITOREXIT : object
            ModelSummary summary = modelSummary;
            // Copy the model, overriding only the coefficients param with the sparse vector.
            return (M)((LinearModel)model.copy(summary, ParamMap$.MODULE$.apply((Seq)Predef$.MODULE$.wrapRefArray((Object[])new ParamPair[]{model.coefficients().$minus$greater((Object)coefficients)}))));
        }
    }

    /**
     * Builds an estimator that wraps {@code estimator} with a significance-based feature
     * selection step.
     *
     * <p>A {@code LinearModelSignificantFeatureSelector} is configured with
     * {@code minSignificance} and {@code featuresCol}, combined via
     * {@code UnwrappedStage.dataOnly} with a {@code FoldedFeaturesStatsAggregator} built
     * from {@code selector} (also bound to {@code featuresCol}), and the result is passed
     * with {@code estimator} to {@code UnwrappedStage.wrap}.
     *
     * <p>NOTE(review): the exact runtime roles of {@code dataOnly} and {@code wrap} are
     * defined outside this file - verify there before relying on ordering or side effects.
     *
     * @param selector        estimator handed to the feature statistics aggregator
     * @param estimator       estimator to be wrapped
     * @param minSignificance value passed to the selector's {@code setMinSignificance}
     * @param featuresCol     name of the features column set on both the selector and the aggregator
     * @return the wrapped estimator
     */
    public <SelectingModel extends ModelWithSummary<SelectingModel> & HasWeights, ResultModel extends LinearModel<ResultModel>> SummarizableEstimator<ResultModel> select(SummarizableEstimator<SelectingModel> selector, SummarizableEstimator<ResultModel> estimator, double minSignificance, String featuresCol) {
        LinearModelSignificantFeatureSelector significanceSelector = (LinearModelSignificantFeatureSelector)new LinearModelSignificantFeatureSelector().setMinSignificance(minSignificance);
        significanceSelector.set(significanceSelector.featuresCol().$minus$greater((Object)featuresCol));
        return UnwrappedStage$.MODULE$.wrap(estimator, UnwrappedStage$.MODULE$.dataOnly(significanceSelector, new FoldedFeaturesStatsAggregator<SelectingModel>(selector).setFeaturesCol(featuresCol)));
    }

    /**
     * Compiler-generated accessor; by its name, presumably the default value for the
     * fourth parameter ({@code featuresCol}) of {@code select}.
     *
     * @return the literal {@code "features"}
     */
    public <SelectingModel extends ModelWithSummary<SelectingModel> & HasWeights, ResultModel extends LinearModel<ResultModel>> String select$default$4() {
        return "features";
    }

    /** Serialization hook: always resolves to the shared {@link #MODULE$} instance. */
    private Object readResolve() {
        return MODULE$;
    }

    /**
     * Private constructor, invoked once from the static initializer: publishes the
     * instance, runs the Logging trait initializer, and sets the metadata key.
     */
    private SignificantFeatureSelector$() {
        MODULE$ = this;
        Logging$class.$init$(this);
        this.WEIGHTS_STAT = "features_stat";
    }
}

