/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.ml.odkl;

import org.apache.commons.math3.distribution.TDistribution;
import org.apache.spark.ml.Estimator;
import org.apache.spark.ml.Transformer;
import org.apache.spark.ml.odkl.HasDescriminantColumn;
import org.apache.spark.ml.odkl.HasFeaturesSignificance;
import org.apache.spark.ml.odkl.HasFeaturesSignificance$class;
import org.apache.spark.ml.odkl.HasWeights;
import org.apache.spark.ml.odkl.HasWeights$class;
import org.apache.spark.ml.odkl.ModelWithSummary;
import org.apache.spark.ml.odkl.SignificantFeatureSelector$;
import org.apache.spark.ml.odkl.SummarizableEstimator;
import org.apache.spark.ml.odkl.WeightsStat;
import org.apache.spark.ml.odkl.WeightsStatRecord;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.param.shared.HasFeaturesCol;
import org.apache.spark.ml.util.Identifiable$;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.RelationalGroupedDataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.UserDefinedFunction;
import org.apache.spark.sql.functions$;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.MetadataBuilder;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import scala.Array$;
import scala.Function1;
import scala.Function3;
import scala.Predef$;
import scala.Serializable;
import scala.collection.Seq;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.reflect.api.TypeTags;
import scala.reflect.runtime.package$;
import scala.runtime.BoxedUnit;

@ScalaSignature(bytes="\u0006\u0001\u0005Mc\u0001B\u0001\u0003\u00015\u0011QDR8mI\u0016$g)Z1ukJ,7o\u0015;biN\fum\u001a:fO\u0006$xN\u001d\u0006\u0003\u0007\u0011\tAa\u001c3lY*\u0011QAB\u0001\u0003[2T!a\u0002\u0005\u0002\u000bM\u0004\u0018M]6\u000b\u0005%Q\u0011AB1qC\u000eDWMC\u0001\f\u0003\ry'oZ\u0002\u0001+\tq\u0011fE\u0003\u0001\u001fM9\"\u0004\u0005\u0002\u0011#5\tA!\u0003\u0002\u0013\t\tYAK]1og\u001a|'/\\3s!\t!R#D\u0001\u0003\u0013\t1\"AA\fICN4U-\u0019;ve\u0016\u001c8+[4oS\u001aL7-\u00198dKB\u0011A\u0003G\u0005\u00033\t\u0011!\u0002S1t/\u0016Lw\r\u001b;t!\tY\u0002%D\u0001\u001d\u0015\tib$\u0001\u0004tQ\u0006\u0014X\r\u001a\u0006\u0003?\u0011\tQ\u0001]1sC6L!!\t\u000f\u0003\u001d!\u000b7OR3biV\u0014Xm]\"pY\"A1\u0005\u0001B\u0001B\u0003%A%\u0001\u0004oKN$X\r\u001a\t\u0004)\u0015:\u0013B\u0001\u0014\u0003\u0005U\u0019V/\\7be&T\u0018M\u00197f\u000bN$\u0018.\\1u_J\u0004\"\u0001K\u0015\r\u0001\u0011)!\u0006\u0001b\u0001W\tq1+\u001a7fGRLgnZ'pI\u0016d\u0017C\u0001\u00173!\ti\u0003'D\u0001/\u0015\u0005y\u0013!B:dC2\f\u0017BA\u0019/\u0005\u001dqu\u000e\u001e5j]\u001e\u00142aM\u001b\u0018\r\u0011!\u0004\u0001\u0001\u001a\u0003\u0019q\u0012XMZ5oK6,g\u000e\u001e \u0011\u0007Q1t%\u0003\u00028\u0005\t\u0001Rj\u001c3fY^KG\u000f[*v[6\f'/\u001f\u0005\ts\u0001\u0011)\u0019!C!u\u0005\u0019Q/\u001b3\u0016\u0003m\u0002\"\u0001P \u000f\u00055j\u0014B\u0001 /\u0003\u0019\u0001&/\u001a3fM&\u0011\u0001)\u0011\u0002\u0007'R\u0014\u0018N\\4\u000b\u0005yr\u0003\u0002C\"\u0001\u0005\u0003\u0005\u000b\u0011B\u001e\u0002\tULG\r\t\u0005\u0006\u000b\u0002!\tAR\u0001\u0007y%t\u0017\u000e\u001e \u0015\u0007\u001dC\u0015\nE\u0002\u0015\u0001\u001dBQa\t#A\u0002\u0011BQ!\u000f#A\u0002mBQa\u0013\u0001\u0005\u00021\u000bab]3u\r\u0016\fG/\u001e:fg\u000e{G\u000e\u0006\u0002N\u001d6\t\u0001\u0001C\u0003P\u0015\u0002\u00071(A\u0003wC2,X\rC\u0003F\u0001\u0011\u0005\u0011\u000b\u0006\u0002H%\")1\u0005\u0015a\u0001I!)A\u000b\u0001C!+\u0006IAO]1og\u001a|'/\u001c\u000b\u0003-*\u0004\"aV4\u000f\u0005a#gBA-c\u001d\tQ\u0016M\u0004\u0002\\A:\u0011AlX\u0007\u0002;*\u0011a\fD\u0001\u0007yI|w\u000e\u001e \n\u0003-I!!\u0003\u0006\n\u0005\u001dA\u0011BA2\u0007\u0003\r\u0019\u0018\u000f\\\u0005\u0003K\u001a\fq\u0001]1dW\u0006<WM\u0003\u0002d\r%\u0011\u0001.\u001b\u0002\n\t\u0006$\u0018M\u0012:b[\u0016T!!\u001a4\t\u000b-\u001c\u0006\u0019\u00017\u0002\u000f\u0011\fG/Y:fiB\u0012QN\u001d\t\u0004]>\fX\"\u00014\n\u0005A4'a\u0002#bi\u0006\u001cX\r\u001e\t\u0003QI$\u0011b\u001d6\u0002\u0002\u0003\u0005)\u0011\u0001;\u0003\u0007}#\u0013'\u0005\u0002-kB\u0011QF^\u0005\u0003o:\u00121!\u00118z\u0011\u0015I\b\u0001\"\u0001{\u00035!xn\u0015;biJ+7m\u001c:egR!1P`A\u0001!\t!B0\u0003\u0002~\u0005\tYq+Z5hQR\u001c8\u000b^1u\u0011\u0015y\b\u00101\u0001W\u00031\u0019\u0018n\u001a8jM&\u001c\u0017M\\2f\u0011\u0019\t\u0019\u0001\u001fa\u0001O\u0005)Qn\u001c3fY\"9\u0011q\u0001\u0001\u0005\u0002\u0005%\u0011!E2p]N$(/^2u\u001b\u0016$\u0018\rZ1uCR1\u00111BA\f\u0003C\u0001B!!\u0004\u0002\u00145\u0011\u0011q\u0002\u0006\u0004\u0003#1\u0017!\u0002;za\u0016\u001c\u0018\u0002BA\u000b\u0003\u001f\u0011\u0001\"T3uC\u0012\fG/\u0019\u0005\t\u00033\t)\u00011\u0001\u0002\u001c\u0005)a-[3mIB!\u0011QBA\u000f\u0013\u0011\ty\"a\u0004\u0003\u0017M#(/^2u\r&,G\u000e\u001a\u0005\t\u0003G\t)\u00011\u0001\u0002&\u0005)1\u000f^1ugB)Q&a\n\u0002,%\u0019\u0011\u0011\u0006\u0018\u0003\u000b\u0005\u0013(/Y=\u0011\u0007Q\ti#C\u0002\u00020\t\u0011\u0011cV3jO\"$8o\u0015;biJ+7m\u001c:e\u0011\u001d\t\u0019\u0004\u0001C!\u0003k\tAaY8qsR\u0019q\"a\u000e\t\u0011\u0005e\u0012\u0011\u0007a\u0001\u0003w\tQ!\u001a=ue\u0006\u0004B!!\u0010\u0002@5\ta$C\u0002\u0002By\u0011\u0001\u0002U1sC6l\u0015\r\u001d\u0005\b\u0003\u000b\u0002A\u0011IA$\u0003=!(/\u00198tM>\u0014XnU2iK6\fG\u0003BA%\u0003\u001f\u0002B!!\u0004\u0002L%!\u0011QJA\b\u0005)\u0019FO];diRK\b/\u001a\u0005\t\u0003#\n\u0019\u00051\u0001\u0002J\u000511o\u00195f[\u0006\u0004")
public class FoldedFeaturesStatsAggregator<SelectingModel extends ModelWithSummary<SelectingModel> & HasWeights>
extends Transformer
implements HasFeaturesSignificance,
HasWeights,
HasFeaturesCol {
    private final SummarizableEstimator<SelectingModel> nested;
    private final String uid;
    private final Param<String> featuresCol;
    private final ModelWithSummary.Block weights;
    private final String index;
    private final String name;
    private final String weight;
    private final String feature_index;
    private final String feature_name;
    private final String average;
    private final String stdDev;
    private final String count;
    private final String significance;

    public final Param<String> featuresCol() {
        return this.featuresCol;
    }

    public final void org$apache$spark$ml$param$shared$HasFeaturesCol$_setter_$featuresCol_$eq(Param x$1) {
        this.featuresCol = x$1;
    }

    public final String getFeaturesCol() {
        return HasFeaturesCol.class.getFeaturesCol((HasFeaturesCol)this);
    }

    @Override
    public ModelWithSummary.Block weights() {
        return this.weights;
    }

    @Override
    public String index() {
        return this.index;
    }

    @Override
    public String name() {
        return this.name;
    }

    @Override
    public String weight() {
        return this.weight;
    }

    @Override
    public void org$apache$spark$ml$odkl$HasWeights$_setter_$weights_$eq(ModelWithSummary.Block x$1) {
        this.weights = x$1;
    }

    @Override
    public void org$apache$spark$ml$odkl$HasWeights$_setter_$index_$eq(String x$1) {
        this.index = x$1;
    }

    @Override
    public void org$apache$spark$ml$odkl$HasWeights$_setter_$name_$eq(String x$1) {
        this.name = x$1;
    }

    @Override
    public void org$apache$spark$ml$odkl$HasWeights$_setter_$weight_$eq(String x$1) {
        this.weight = x$1;
    }

    @Override
    public String feature_index() {
        return this.feature_index;
    }

    @Override
    public String feature_name() {
        return this.feature_name;
    }

    @Override
    public String average() {
        return this.average;
    }

    @Override
    public String stdDev() {
        return this.stdDev;
    }

    @Override
    public String count() {
        return this.count;
    }

    @Override
    public String significance() {
        return this.significance;
    }

    @Override
    public void org$apache$spark$ml$odkl$HasFeaturesSignificance$_setter_$feature_index_$eq(String x$1) {
        this.feature_index = x$1;
    }

    @Override
    public void org$apache$spark$ml$odkl$HasFeaturesSignificance$_setter_$feature_name_$eq(String x$1) {
        this.feature_name = x$1;
    }

    @Override
    public void org$apache$spark$ml$odkl$HasFeaturesSignificance$_setter_$average_$eq(String x$1) {
        this.average = x$1;
    }

    @Override
    public void org$apache$spark$ml$odkl$HasFeaturesSignificance$_setter_$stdDev_$eq(String x$1) {
        this.stdDev = x$1;
    }

    @Override
    public void org$apache$spark$ml$odkl$HasFeaturesSignificance$_setter_$count_$eq(String x$1) {
        this.count = x$1;
    }

    @Override
    public void org$apache$spark$ml$odkl$HasFeaturesSignificance$_setter_$significance_$eq(String x$1) {
        this.significance = x$1;
    }

    @Override
    public ModelWithSummary.Block featuresSignificance() {
        return HasFeaturesSignificance$class.featuresSignificance(this);
    }

    public String uid() {
        return this.uid;
    }

    public FoldedFeaturesStatsAggregator<SelectingModel> setFeaturesCol(String value) {
        return (FoldedFeaturesStatsAggregator)this.set(this.featuresCol(), value);
    }

    public Dataset<Row> transform(Dataset<?> dataset) {
        RelationalGroupedDataset relationalGroupedDataset;
        ModelWithSummary model = (ModelWithSummary)((Estimator)this.nested).fit(dataset);
        Dataset<Row> weightsDf = model.summary().$(this.weights());
        UserDefinedFunction sig = functions$.MODULE$.udf((Function3)new Serializable(this){
            public static final long serialVersionUID = 0L;

            public final double apply(double avg, double std, long N) {
                TDistribution tDist = new TDistribution((double)(N - 1L));
                double critVal = tDist.inverseCumulativeProbability(0.975);
                double confidence = critVal * std / Math.sqrt(N);
                return confidence <= 0.0 ? 0.0 : Math.abs(avg / confidence);
            }
        }, ((TypeTags)package$.MODULE$.universe()).TypeTag().Double(), ((TypeTags)package$.MODULE$.universe()).TypeTag().Double(), ((TypeTags)package$.MODULE$.universe()).TypeTag().Double(), ((TypeTags)package$.MODULE$.universe()).TypeTag().Long());
        ModelWithSummary modelWithSummary = model;
        if (modelWithSummary instanceof HasDescriminantColumn) {
            ModelWithSummary modelWithSummary2 = modelWithSummary;
            relationalGroupedDataset = weightsDf.groupBy((Seq)Predef$.MODULE$.wrapRefArray((Object[])new Column[]{weightsDf.apply(this.feature_index()), weightsDf.apply(this.feature_name()), weightsDf.apply(((HasDescriminantColumn)((Object)modelWithSummary2)).getDescriminantColumn())}));
        } else {
            relationalGroupedDataset = weightsDf.groupBy((Seq)Predef$.MODULE$.wrapRefArray((Object[])new Column[]{weightsDf.apply(this.feature_index()), weightsDf.apply(this.feature_name())}));
        }
        RelationalGroupedDataset grouped = relationalGroupedDataset;
        Dataset significance = grouped.agg(functions$.MODULE$.avg(weightsDf.apply(this.weight())).as(this.average()), (Seq)Predef$.MODULE$.wrapRefArray((Object[])new Column[]{functions$.MODULE$.stddev_samp(weightsDf.apply(this.weight())).as(this.stdDev()), functions$.MODULE$.count(weightsDf.apply(this.weight())).as(this.count()), sig.apply((Seq)Predef$.MODULE$.wrapRefArray((Object[])new Column[]{functions$.MODULE$.avg(weightsDf.apply(this.weight())), functions$.MODULE$.stddev_samp(weightsDf.apply(this.weight())), functions$.MODULE$.count(weightsDf.apply(this.weight()))})).as(this.significance())})).repartition(1);
        WeightsStat statRecords = this.toStatRecords((Dataset<Row>)significance, model);
        return dataset.withColumn((String)this.$(this.featuresCol()), dataset.apply((String)this.$(this.featuresCol())).as((String)this.$(this.featuresCol()), this.constructMetadata(dataset.schema().apply((String)this.$(this.featuresCol())), statRecords.stats())));
    }

    public WeightsStat toStatRecords(Dataset<Row> significance, SelectingModel model) {
        Object object;
        int indexIndex = significance.schema().fieldIndex(this.feature_index());
        int nameIndex = significance.schema().fieldIndex(this.feature_name());
        int averageIndex = significance.schema().fieldIndex(this.average());
        int stdDevIndex = significance.schema().fieldIndex(this.stdDev());
        int countIndex = significance.schema().fieldIndex(this.count());
        int significanceIndex = significance.schema().fieldIndex(this.significance());
        SelectingModel SelectingModel2 = model;
        if (SelectingModel2 instanceof HasDescriminantColumn) {
            SelectingModel SelectingModel3 = SelectingModel2;
            int index = significance.schema().fieldIndex(((HasDescriminantColumn)SelectingModel3).getDescriminantColumn());
            object = new Serializable(this, index){
                public static final long serialVersionUID = 0L;
                private final int index$1;

                public final String apply(Row r) {
                    return r.getString(this.index$1);
                }
                {
                    this.index$1 = index$1;
                }
            };
        } else {
            object = new Serializable(this){
                public static final long serialVersionUID = 0L;

                public final String apply(Row r) {
                    return "";
                }
            };
        }
        Serializable discriminantExtractor = object;
        return new WeightsStat((WeightsStatRecord[])significance.rdd().map((Function1)new Serializable(this, indexIndex, nameIndex, averageIndex, stdDevIndex, countIndex, significanceIndex, (Function1)discriminantExtractor){
            public static final long serialVersionUID = 0L;
            private final int indexIndex$1;
            private final int nameIndex$1;
            private final int averageIndex$1;
            private final int stdDevIndex$1;
            private final int countIndex$1;
            private final int significanceIndex$1;
            private final Function1 discriminantExtractor$1;

            public final WeightsStatRecord apply(Row r) {
                return new WeightsStatRecord(r.getInt(this.indexIndex$1), r.getString(this.nameIndex$1), (String)this.discriminantExtractor$1.apply((Object)r), r.getDouble(this.averageIndex$1), r.getDouble(this.stdDevIndex$1), r.getLong(this.countIndex$1), r.getDouble(this.significanceIndex$1), true);
            }
            {
                void var8_8;
                this.indexIndex$1 = indexIndex$1;
                this.nameIndex$1 = nameIndex$1;
                this.averageIndex$1 = averageIndex$1;
                this.stdDevIndex$1 = stdDevIndex$1;
                this.countIndex$1 = countIndex$1;
                this.significanceIndex$1 = significanceIndex$1;
                this.discriminantExtractor$1 = var8_8;
            }
        }, ClassTag$.MODULE$.apply(WeightsStatRecord.class)).collect());
    }

    public Metadata constructMetadata(StructField field, WeightsStatRecord[] stats) {
        MetadataBuilder builder = new MetadataBuilder();
        builder.putMetadataArray(SignificantFeatureSelector$.MODULE$.WEIGHTS_STAT(), (Metadata[])Predef$.MODULE$.refArrayOps((Object[])stats).map((Function1)new Serializable(this){
            public static final long serialVersionUID = 0L;

            public final Metadata apply(WeightsStatRecord x$1) {
                return x$1.toMetadata();
            }
        }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Metadata.class))));
        Object object = field.metadata() == null ? BoxedUnit.UNIT : builder.withMetadata(field.metadata());
        return builder.build();
    }

    public Transformer copy(ParamMap extra) {
        return new FoldedFeaturesStatsAggregator<SelectingModel>(this.nested.copy(extra));
    }

    public StructType transformSchema(StructType schema) {
        return schema;
    }

    public FoldedFeaturesStatsAggregator(SummarizableEstimator<SelectingModel> nested, String uid) {
        this.nested = nested;
        this.uid = uid;
        HasFeaturesSignificance$class.$init$(this);
        HasWeights$class.$init$(this);
        HasFeaturesCol.class.$init$((HasFeaturesCol)this);
    }

    public FoldedFeaturesStatsAggregator(SummarizableEstimator<SelectingModel> nested) {
        this(nested, Identifiable$.MODULE$.randomUID("foldedFeaturesStatsAggregator"));
    }
}

