/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.ml.odkl;

import java.io.IOException;
import odkl.analysis.spark.util.RDDOperations$;
import org.apache.spark.annotation.DeveloperApi;
import org.apache.spark.ml.Estimator;
import org.apache.spark.ml.odkl.CRRSamplerEstimator$;
import org.apache.spark.ml.odkl.CRRSamplerModel;
import org.apache.spark.ml.odkl.CRRSamplerParams;
import org.apache.spark.ml.odkl.CRRSamplerParams$class;
import org.apache.spark.ml.odkl.HasGroupByColumns;
import org.apache.spark.ml.odkl.HasGroupByColumns$class;
import org.apache.spark.ml.param.DoubleParam;
import org.apache.spark.ml.param.IntParam;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.param.ParamMap$;
import org.apache.spark.ml.param.ParamPair;
import org.apache.spark.ml.param.Params;
import org.apache.spark.ml.param.StringArrayParam;
import org.apache.spark.ml.param.shared.HasInputCol;
import org.apache.spark.ml.param.shared.HasLabelCol;
import org.apache.spark.ml.util.DefaultParamsWritable;
import org.apache.spark.ml.util.Identifiable;
import org.apache.spark.ml.util.Identifiable$;
import org.apache.spark.ml.util.MLWritable;
import org.apache.spark.ml.util.MLWriter;
import org.apache.spark.rdd.RDD$;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.functions$;
import org.apache.spark.sql.types.StructType;
import scala.Array$;
import scala.Function0;
import scala.Function1;
import scala.MatchError;
import scala.Predef$;
import scala.Serializable;
import scala.StringContext;
import scala.Tuple2;
import scala.collection.IterableLike;
import scala.collection.Seq;
import scala.math.Numeric;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxesRunTime;
import scala.runtime.LongRef;

@ScalaSignature(bytes="\u0006\u000194A!\u0001\u0002\u0001\u001b\t\u00192I\u0015*TC6\u0004H.\u001a:FgRLW.\u0019;pe*\u00111\u0001B\u0001\u0005_\u0012\\GN\u0003\u0002\u0006\r\u0005\u0011Q\u000e\u001c\u0006\u0003\u000f!\tQa\u001d9be.T!!\u0003\u0006\u0002\r\u0005\u0004\u0018m\u00195f\u0015\u0005Y\u0011aA8sO\u000e\u00011\u0003\u0002\u0001\u000f-q\u00012a\u0004\t\u0013\u001b\u0005!\u0011BA\t\u0005\u0005%)5\u000f^5nCR|'\u000f\u0005\u0002\u0014)5\t!!\u0003\u0002\u0016\u0005\ty1I\u0015*TC6\u0004H.\u001a:N_\u0012,G\u000e\u0005\u0002\u001855\t\u0001D\u0003\u0002\u001a\t\u0005!Q\u000f^5m\u0013\tY\u0002DA\u000bEK\u001a\fW\u000f\u001c;QCJ\fWn],sSR\f'\r\\3\u0011\u0005Mi\u0012B\u0001\u0010\u0003\u0005A\u0019%KU*b[BdWM\u001d)be\u0006l7\u000f\u0003\u0005!\u0001\t\u0015\r\u0011\"\u0011\"\u0003\r)\u0018\u000eZ\u000b\u0002EA\u00111%\u000b\b\u0003I\u001dj\u0011!\n\u0006\u0002M\u0005)1oY1mC&\u0011\u0001&J\u0001\u0007!J,G-\u001a4\n\u0005)Z#AB*ue&twM\u0003\u0002)K!AQ\u0006\u0001B\u0001B\u0003%!%\u0001\u0003vS\u0012\u0004\u0003\"B\u0018\u0001\t\u0003\u0001\u0014A\u0002\u001fj]&$h\b\u0006\u00022eA\u00111\u0003\u0001\u0005\u0006A9\u0002\rA\t\u0005\bi\u0001\u0011\r\u0011\"\u00016\u0003I)\u0007\u0010]3di\u0016$g*^7TC6\u0004H.Z:\u0016\u0003Y\u0002\"a\u000e\u001e\u000e\u0003aR!!\u000f\u0003\u0002\u000bA\f'/Y7\n\u0005mB$\u0001C%oiB\u000b'/Y7\t\ru\u0002\u0001\u0015!\u00037\u0003M)\u0007\u0010]3di\u0016$g*^7TC6\u0004H.Z:!\u0011\u0015y\u0004\u0001\"\u0001A\u0003i\u0019X\r^#ya\u0016\u001cG/\u001a3Ok6\u0014WM](g'\u0006l\u0007\u000f\\3t)\t\t%)D\u0001\u0001\u0011\u0015\u0019e\b1\u0001E\u0003\u00151\u0018\r\\;f!\t!S)\u0003\u0002GK\t\u0019\u0011J\u001c;\t\u000b=\u0002A\u0011\u0001%\u0015\u0003EBQA\u0013\u0001\u0005B-\u000b1AZ5u)\t\u0011B\nC\u0003N\u0013\u0002\u0007a*A\u0004eCR\f7/\u001a;\u0011\u0005=\u0013V\"\u0001)\u000b\u0005E3\u0011aA:rY&\u00111\u000b\u0015\u0002\n\t\u0006$\u0018M\u0012:b[\u0016DQ!\u0016\u0001\u0005BY\u000bAaY8qsR\u0011\u0011g\u0016\u0005\u00061R\u0003\r!W\u0001\u0006Kb$(/\u0019\t\u0003oi
K!a\u0017\u001d\u0003\u0011A\u000b'/Y7NCBDQ!\u0018\u0001\u0005By\u000bq\u0002\u001e:b]N4wN]7TG\",W.\u0019\u000b\u0003?\u0016\u0004\"\u0001Y2\u000e\u0003\u0005T!A\u0019)\u0002\u000bQL\b/Z:\n\u0005\u0011\f'AC*ueV\u001cG\u000fV=qK\")a\r\u0018a\u0001?\u000611o\u00195f[\u0006D#\u0001\u00185\u0011\u0005%dW\"\u00016\u000b\u0005-4\u0011AC1o]>$\u0018\r^5p]&\u0011QN\u001b\u0002\r\t\u00164X\r\\8qKJ\f\u0005/\u001b")
/**
 * Estimator that works out the per-item sample rate needed to end up with roughly
 * {@code expectedNumSamples} samples and returns a {@link CRRSamplerModel} configured with it.
 *
 * <p>This source was reconstructed by a decompiler from a compiled Scala class, so it still
 * carries compiler plumbing: name-mangled {@code ..._setter_$..._$eq} methods that store the
 * {@code Param} objects of the mixed-in parameter traits, and closure bodies rendered as
 * {@code new Serializable(...) {...}}. Those closure bodies are kept exactly as decompiled.
 *
 * <p>NOTE(review): calls written as {@code SomeTrait.class.method(...)} look like the decompiler's
 * rendering of the Scala trait implementation class ({@code SomeTrait$class}), by analogy with the
 * {@code CRRSamplerParams$class} / {@code HasGroupByColumns$class} calls in this file - confirm
 * before relying on it.
 */
public class CRRSamplerEstimator
extends Estimator<CRRSamplerModel>
implements DefaultParamsWritable,
CRRSamplerParams {
    /** Unique identifier of this estimator instance; assigned once in the constructor. */
    private final String uid;
    /** Target number of samples in the result; assigned once in the constructor. */
    private final IntParam expectedNumSamples;

    // The fields below are deliberately NOT final: they are assigned from the trait
    // `_setter_$..._$eq` methods further down, and Java only permits assigning a final
    // field from a constructor or an initializer.
    private DoubleParam groupSampleRate;
    private DoubleParam itemSampleRate;
    private DoubleParam rankingPower;
    private IntParam shuffleToPartitions;
    private Param<String> labelCol;
    private StringArrayParam groupByColumns;
    private Param<String> inputCol;

    /** @return the group sample rate parameter */
    @Override
    public DoubleParam groupSampleRate() {
        return this.groupSampleRate;
    }

    /** @return the item sample rate parameter */
    @Override
    public DoubleParam itemSampleRate() {
        return this.itemSampleRate;
    }

    /** @return the ranking power parameter */
    @Override
    public DoubleParam rankingPower() {
        return this.rankingPower;
    }

    /** @return the shuffle-to-partitions parameter */
    @Override
    public IntParam shuffleToPartitions() {
        return this.shuffleToPartitions;
    }

    /** Stores the {@code groupSampleRate} param object (trait-field setter generated by the Scala compiler). */
    @Override
    public void org$apache$spark$ml$odkl$CRRSamplerParams$_setter_$groupSampleRate_$eq(DoubleParam x$1) {
        this.groupSampleRate = x$1;
    }

    /** Stores the {@code itemSampleRate} param object (trait-field setter generated by the Scala compiler). */
    @Override
    public void org$apache$spark$ml$odkl$CRRSamplerParams$_setter_$itemSampleRate_$eq(DoubleParam x$1) {
        this.itemSampleRate = x$1;
    }

    /** Stores the {@code rankingPower} param object (trait-field setter generated by the Scala compiler). */
    @Override
    public void org$apache$spark$ml$odkl$CRRSamplerParams$_setter_$rankingPower_$eq(DoubleParam x$1) {
        this.rankingPower = x$1;
    }

    /** Stores the {@code shuffleToPartitions} param object (trait-field setter generated by the Scala compiler). */
    @Override
    public void org$apache$spark$ml$odkl$CRRSamplerParams$_setter_$shuffleToPartitions_$eq(IntParam x$1) {
        this.shuffleToPartitions = x$1;
    }

    /** Delegates to the {@code CRRSamplerParams} trait implementation. */
    @Override
    public CRRSamplerParams setGroupSampleRate(double value) {
        return CRRSamplerParams$class.setGroupSampleRate(this, value);
    }

    /** Delegates to the {@code CRRSamplerParams} trait implementation. */
    @Override
    public CRRSamplerParams setItemSampleRate(double value) {
        return CRRSamplerParams$class.setItemSampleRate(this, value);
    }

    /** Delegates to the {@code CRRSamplerParams} trait implementation. */
    @Override
    public CRRSamplerParams setRankingPower(double value) {
        return CRRSamplerParams$class.setRankingPower(this, value);
    }

    /** Delegates to the {@code CRRSamplerParams} trait implementation. */
    @Override
    public CRRSamplerParams setShufflerToPartitions(int value) {
        return CRRSamplerParams$class.setShufflerToPartitions(this, value);
    }

    /** @return the label column parameter */
    public final Param<String> labelCol() {
        return this.labelCol;
    }

    /** Stores the {@code labelCol} param object (trait-field setter generated by the Scala compiler). */
    public final void org$apache$spark$ml$param$shared$HasLabelCol$_setter_$labelCol_$eq(Param x$1) {
        this.labelCol = x$1;
    }

    /** Delegates to the {@code HasLabelCol} trait implementation. */
    public final String getLabelCol() {
        return HasLabelCol.class.getLabelCol((HasLabelCol)this);
    }

    /** @return the group-by columns parameter */
    @Override
    public final StringArrayParam groupByColumns() {
        return this.groupByColumns;
    }

    /** Stores the {@code groupByColumns} param object (trait-field setter generated by the Scala compiler). */
    @Override
    public final void org$apache$spark$ml$odkl$HasGroupByColumns$_setter_$groupByColumns_$eq(StringArrayParam x$1) {
        this.groupByColumns = x$1;
    }

    /** Delegates to the {@code HasGroupByColumns} trait implementation. */
    @Override
    public HasGroupByColumns setGroupByColumns(Seq<String> columns) {
        return HasGroupByColumns$class.setGroupByColumns(this, columns);
    }

    /** @return the input column parameter */
    public final Param<String> inputCol() {
        return this.inputCol;
    }

    /** Stores the {@code inputCol} param object (trait-field setter generated by the Scala compiler). */
    public final void org$apache$spark$ml$param$shared$HasInputCol$_setter_$inputCol_$eq(Param x$1) {
        this.inputCol = x$1;
    }

    /** Delegates to the {@code HasInputCol} trait implementation. */
    public final String getInputCol() {
        return HasInputCol.class.getInputCol((HasInputCol)this);
    }

    /** Delegates to the {@code DefaultParamsWritable} trait implementation. */
    public MLWriter write() {
        return DefaultParamsWritable.class.write((DefaultParamsWritable)this);
    }

    /**
     * Delegates to the {@code MLWritable} trait implementation.
     *
     * @param path destination handed to the trait implementation
     * @throws IOException declared by the delegated save operation
     */
    public void save(String path) throws IOException {
        MLWritable.class.save((MLWritable)this, (String)path);
    }

    /** @return the identifier passed to (or generated by) the constructor */
    public String uid() {
        return this.uid;
    }

    /** @return the parameter holding the expected number of samples in the result */
    public IntParam expectedNumSamples() {
        return this.expectedNumSamples;
    }

    /**
     * Sets the {@code expectedNumSamples} parameter.
     *
     * @param value expected number of samples in the result
     * @return this estimator
     */
    public CRRSamplerEstimator setExpectedNumberOfSamples(int value) {
        return (CRRSamplerEstimator)this.set((Param)this.expectedNumSamples(), BoxesRunTime.boxToInteger((int)value));
    }

    /**
     * Estimates how many samples the dataset yields and derives the item sample rate that brings
     * that number down to {@code expectedNumSamples}.
     *
     * <p>When {@code rankingPower > 0} and at least one group-by column is configured, only rows
     * belonging to groups that contain both a positive-label row ({@code label > 0.0}) and a
     * non-positive one are counted; groups are the ones produced by
     * {@code groupWithinPartitionsBy} on the key column. Otherwise every row of the dataset is
     * counted. The estimate is multiplied by {@code groupSampleRate}, and the resulting rate is
     * {@code min(1.0, expectedNumSamples / discounted)}.
     *
     * @param dataset data to estimate the sample count from
     * @return a model carrying this estimator's params, with {@code itemSampleRate} set to the
     *         computed rate and its parent set to this estimator
     */
    public CRRSamplerModel fit(DataFrame dataset) {
        final double totalSamples;
        boolean countWithinGroups = BoxesRunTime.unboxToDouble((Object)this.$((Param)this.rankingPower())) > 0.0
                && this.isDefined((Param)this.groupByColumns())
                && ((String[])this.$((Param)this.groupByColumns())).length > 0;
        if (countWithinGroups) {
            // Both branches assign the pair exactly once, so the value used below is always the
            // one produced by the branch that actually ran.
            final DataFrame withKey;
            final int keyIndex;
            if (((String[])this.$((Param)this.groupByColumns())).length == 1) {
                // A single grouping column can be used as the key directly.
                withKey = dataset;
                keyIndex = dataset.schema().fieldIndex((String)Predef$.MODULE$.refArrayOps((Object[])this.$((Param)this.groupByColumns())).head());
            } else {
                // Several grouping columns: pack them into one temporary struct column named "<uid>_tmpKey".
                Column key = functions$.MODULE$.struct((Seq)Predef$.MODULE$.wrapRefArray((Object[])Predef$.MODULE$.refArrayOps((Object[])this.$((Param)this.groupByColumns())).map((Function1)new Serializable(this, dataset){
                    public static final long serialVersionUID = 0L;
                    private final DataFrame dataset$2;

                    public final Column apply(String x) {
                        return this.dataset$2.apply(x);
                    }
                    {
                        this.dataset$2 = dataset$2;
                    }
                }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Column.class)))));
                String keyName = new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"", "_tmpKey"})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{this.uid()}));
                withKey = dataset.withColumn(keyName, key);
                keyIndex = withKey.schema().fieldIndex(keyName);
            }
            int labelIndex = withKey.schema().fieldIndex((String)this.$(this.labelCol()));
            // Each group contributes its row count, or 0 when it lacks positives or lacks negatives.
            totalSamples = RDD$.MODULE$.numericRDDToDoubleRDDFunctions(RDDOperations$.MODULE$.ImplicitRDDDecorator(withKey.rdd(), ClassTag$.MODULE$.apply(Row.class)).groupWithinPartitionsBy(new Serializable(this, keyIndex){
                public static final long serialVersionUID = 0L;
                private final int keyIndex$2;

                public final Object apply(Row x) {
                    return x.get(this.keyIndex$2);
                }
                {
                    this.keyIndex$2 = keyIndex$2;
                }
            }, ClassTag$.MODULE$.Any()).map((Function1)new Serializable(this, labelIndex){
                public static final long serialVersionUID = 0L;
                public final int labelIndex$3;

                public final long apply(Tuple2<Object, Seq<Row>> x) {
                    LongRef numPositives2 = new LongRef(0L);
                    LongRef numNegatives2 = new LongRef(0L);
                    ((IterableLike)x._2()).foreach((Function1)new Serializable(this, numPositives2, numNegatives2){
                        public static final long serialVersionUID = 0L;
                        private final /* synthetic */ $anonfun$15 $outer;
                        private final LongRef numPositives$1;
                        private final LongRef numNegatives$1;

                        public final void apply(Row row) {
                            if (row.getDouble(this.$outer.labelIndex$3) > 0.0) {
                                ++this.numPositives$1.elem;
                            } else {
                                ++this.numNegatives$1.elem;
                            }
                        }
                        {
                            if ($outer == null) {
                                throw new NullPointerException();
                            }
                            this.$outer = $outer;
                            this.numPositives$1 = numPositives$1;
                            this.numNegatives$1 = numNegatives$1;
                        }
                    });
                    return numNegatives2.elem > 0L && numPositives2.elem > 0L ? numNegatives2.elem + numPositives2.elem : 0L;
                }
                {
                    this.labelIndex$3 = labelIndex$3;
                }
            }, ClassTag$.MODULE$.Long()), (Numeric)Numeric.LongIsIntegral$.MODULE$).sum();
        } else {
            totalSamples = dataset.count();
        }
        double discountedByGroupRate = totalSamples * BoxesRunTime.unboxToDouble((Object)this.$((Param)this.groupSampleRate()));
        // The rate is capped at 1.0; this also covers a zero discounted estimate with a positive target.
        double requiredItemSampleRate = Math.min(1.0, (double)BoxesRunTime.unboxToInt((Object)this.$((Param)this.expectedNumSamples())) / discountedByGroupRate);
        this.logInfo((Function0)new Serializable(this, totalSamples, discountedByGroupRate, requiredItemSampleRate){
            public static final long serialVersionUID = 0L;
            private final double totalSamples$1;
            private final double discountedByGroupRate$1;
            private final double requiredItemSampleRate$1;

            public final String apply() {
                return new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"Estimated total number of samples ", ", after groups sampling ", ". Chosen item sample rate is ", ""})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToDouble((double)this.totalSamples$1), BoxesRunTime.boxToDouble((double)this.discountedByGroupRate$1), BoxesRunTime.boxToDouble((double)this.requiredItemSampleRate$1)}));
            }
            {
                this.totalSamples$1 = totalSamples$1;
                this.discountedByGroupRate$1 = discountedByGroupRate$1;
                this.requiredItemSampleRate$1 = requiredItemSampleRate$1;
            }
        });
        CRRSamplerModel model = (CRRSamplerModel)this.copyValues((Params)new CRRSamplerModel(), ParamMap$.MODULE$.apply((Seq)Predef$.MODULE$.wrapRefArray((Object[])new ParamPair[]{this.itemSampleRate().$minus$greater((Object)BoxesRunTime.boxToDouble((double)requiredItemSampleRate))})));
        return (CRRSamplerModel)model.setParent(this);
    }

    /**
     * Copies this estimator via {@code defaultCopy}.
     *
     * @param extra parameter overrides passed through to {@code defaultCopy}
     * @return the copy
     */
    public CRRSamplerEstimator copy(ParamMap extra) {
        return (CRRSamplerEstimator)this.defaultCopy(extra);
    }

    /**
     * Returns the schema unchanged.
     *
     * @param schema input schema
     * @return the same {@code schema} instance
     */
    @DeveloperApi
    public StructType transformSchema(StructType schema) {
        return schema;
    }

    /**
     * Creates the estimator with an explicit identifier, runs the trait initializers, creates the
     * {@code expectedNumSamples} param and sets the default {@code groupSampleRate} to 1.0.
     *
     * @param uid unique identifier of this instance
     */
    public CRRSamplerEstimator(String uid) {
        this.uid = uid;
        MLWritable.class.$init$((MLWritable)this);
        DefaultParamsWritable.class.$init$((DefaultParamsWritable)this);
        HasInputCol.class.$init$((HasInputCol)this);
        HasGroupByColumns$class.$init$(this);
        HasLabelCol.class.$init$((HasLabelCol)this);
        CRRSamplerParams$class.$init$(this);
        this.expectedNumSamples = new IntParam((Identifiable)this, "expectedNumSamples", "The expected number of samples in the result. Required.", (Function1)new $anonfun$5(this));
        this.setDefault((Seq)Predef$.MODULE$.wrapRefArray((Object[])new ParamPair[]{this.groupSampleRate().$minus$greater((Object)BoxesRunTime.boxToDouble((double)1.0))}));
    }

    /** Creates the estimator with a random identifier prefixed with {@code "crrSamplerEstimator"}. */
    public CRRSamplerEstimator() {
        this(Identifiable$.MODULE$.randomUID("crrSamplerEstimator"));
    }
}

