/*
 * Decompiled with CFR 0.152.
 */
package org.apache.crunch.lib;

import org.apache.crunch.MapFn;
import org.apache.crunch.PCollection;
import org.apache.crunch.PTable;
import org.apache.crunch.Pair;
import org.apache.crunch.lib.PTables;
import org.apache.crunch.lib.SampleUtils;
import org.apache.crunch.types.PTableType;
import org.apache.crunch.types.PType;
import org.apache.crunch.types.PTypeFamily;

/**
 * Static helpers that attach sampling stages to Crunch collections.
 *
 * <p>Two families are provided: probability-based sampling ({@code sample}) and fixed-size
 * reservoir sampling ({@code reservoirSample}, {@code weightedReservoirSample},
 * {@code groupedWeightedReservoirSample}). Every overload without a {@code seed} parameter
 * delegates to its seeded sibling with a {@code null} seed.
 *
 * <p>NOTE(review): the actual selection logic lives in {@code SampleUtils}, whose declarations
 * are not visible from this file. Its function classes are therefore instantiated as raw types
 * here; parameterize them once their generic signatures have been confirmed.
 */
public class Sample {
    /**
     * Samples a collection using the given probability and no explicit seed.
     *
     * @param input the collection to sample from
     * @param probability the probability handed to the sampling function
     * @return the collection produced by the sampling stage
     */
    public static <S> PCollection<S> sample(PCollection<S> input, double probability) {
        return sample(input, null, probability);
    }

    /**
     * Samples a collection by running a {@code SampleUtils.SampleFn} over it in a stage named
     * {@code sample(<probability to two decimals>)}.
     *
     * @param input the collection to sample from
     * @param seed seed forwarded to the sampling function; may be {@code null}
     *     (presumably meaning "pick a seed automatically" -- confirm against SampleUtils)
     * @param probability the probability handed to the sampling function
     * @return the collection produced by the sampling stage, typed like {@code input}
     */
    public static <S> PCollection<S> sample(PCollection<S> input, Long seed, double probability) {
        String stageName = String.format("sample(%.2f)", probability);
        return input.parallelDo(stageName, new SampleUtils.SampleFn(probability, seed), input.getPType());
    }

    /**
     * Table flavour of {@link #sample(PCollection, double)}: samples the table's key/value pairs
     * and converts the result back into a table.
     *
     * @param input the table to sample from
     * @param probability the probability handed to the sampling function
     * @return a table holding the sampled pairs
     */
    public static <K, V> PTable<K, V> sample(PTable<K, V> input, double probability) {
        // The upcast is essential. With a PTable-typed argument, overload resolution would pick
        // this very method (the most specific match) and recurse until the stack overflows.
        return PTables.asPTable(sample((PCollection<Pair<K, V>>) input, probability));
    }

    /**
     * Table flavour of {@link #sample(PCollection, Long, double)}: samples the table's key/value
     * pairs and converts the result back into a table.
     *
     * @param input the table to sample from
     * @param seed seed forwarded to the sampling function; may be {@code null}
     * @param probability the probability handed to the sampling function
     * @return a table holding the sampled pairs
     */
    public static <K, V> PTable<K, V> sample(PTable<K, V> input, Long seed, double probability) {
        // Same upcast as above: forces binding to the PCollection overload instead of self-recursion.
        return PTables.asPTable(sample((PCollection<Pair<K, V>>) input, seed, probability));
    }

    /**
     * Reservoir-samples a collection with no explicit seed.
     *
     * @param input the collection to sample from
     * @param sampleSize the sample size forwarded to the reservoir sampler
     * @return the sampled elements
     */
    public static <T> PCollection<T> reservoirSample(PCollection<T> input, int sampleSize) {
        return reservoirSample(input, sampleSize, null);
    }

    /**
     * Reservoir-samples a collection by pairing every element with the constant weight {@code 1}
     * and delegating to {@link #weightedReservoirSample(PCollection, int, Long)}.
     *
     * @param input the collection to sample from
     * @param sampleSize the sample size forwarded to the reservoir sampler
     * @param seed seed forwarded to the sampler; may be {@code null}
     * @return the sampled elements
     */
    public static <T> PCollection<T> reservoirSample(PCollection<T> input, int sampleSize, Long seed) {
        PTypeFamily ptf = input.getTypeFamily();
        PType<Pair<T, Integer>> weightedType = ptf.pairs(input.getPType(), ptf.ints());
        PCollection<Pair<T, Integer>> unitWeighted = input.parallelDo("Map to pairs for reservoir sampling", new MapFn<T, Pair<T, Integer>>(){

            @Override
            public Pair<T, Integer> map(T t) {
                // Every element gets the same weight.
                return Pair.of(t, 1);
            }
        }, weightedType);
        return weightedReservoirSample(unitWeighted, sampleSize, seed);
    }

    /**
     * Weighted reservoir sampling with no explicit seed.
     *
     * @param input pairs of (element, numeric weight)
     * @param sampleSize the sample size forwarded to the reservoir sampler
     * @return the sampled elements, without their weights
     */
    public static <T, N extends Number> PCollection<T> weightedReservoirSample(PCollection<Pair<T, N>> input, int sampleSize) {
        return weightedReservoirSample(input, sampleSize, null);
    }

    /**
     * Weighted reservoir sampling over a single implicit group: every input pair is keyed with
     * group {@code 0}, the grouped sampler is run with a one-element size array, and the group
     * key is stripped from the result again.
     *
     * @param input pairs of (element, numeric weight)
     * @param sampleSize the sample size used for the single group
     * @param seed seed forwarded to the sampler; may be {@code null}
     * @return the sampled elements, without their weights
     */
    public static <T, N extends Number> PCollection<T> weightedReservoirSample(PCollection<Pair<T, N>> input, int sampleSize, Long seed) {
        PTypeFamily ptf = input.getTypeFamily();
        PTable<Integer, Pair<T, N>> groupedIn = input.parallelDo(new MapFn<Pair<T, N>, Pair<Integer, Pair<T, N>>>(){

            @Override
            public Pair<Integer, Pair<T, N>> map(Pair<T, N> p) {
                // All records share group 0, matching index 0 of the size array below.
                return Pair.of(0, p);
            }
        }, ptf.tableOf(ptf.ints(), input.getPType()));
        int[] ss = new int[]{sampleSize};
        PCollection<Pair<Integer, T>> sampled = groupedWeightedReservoirSample(groupedIn, ss, seed);
        // First sub-type of the (element, weight) pair type is the element type. The sub-type
        // list is not statically typed, hence the narrowly scoped unchecked cast.
        @SuppressWarnings("unchecked")
        PType<T> elementType = (PType<T>) input.getPType().getSubTypes().get(0);
        return sampled.parallelDo("Extract sampled value from pair", new MapFn<Pair<Integer, T>, T>(){

            @Override
            public T map(Pair<Integer, T> p) {
                return p.second();
            }
        }, elementType);
    }

    /**
     * Grouped weighted reservoir sampling with no explicit seed.
     *
     * @param input table of group key to (element, numeric weight)
     * @param sampleSizes sample sizes forwarded to the sampler (presumably indexed by group key
     *     -- confirm against SampleUtils)
     * @return pairs of (group key, sampled element)
     */
    public static <T, N extends Number> PCollection<Pair<Integer, T>> groupedWeightedReservoirSample(PTable<Integer, Pair<T, N>> input, int[] sampleSizes) {
        return groupedWeightedReservoirSample(input, sampleSizes, null);
    }

    /**
     * Grouped weighted reservoir sampling. The pipeline is:
     * <ol>
     *   <li>{@code SampleUtils.ReservoirSampleFn} turns the input into a table of
     *       group key to (double, element);</li>
     *   <li>that table is grouped by key, passing {@code 1} to {@code groupByKey}
     *       (presumably the partition count -- confirm);</li>
     *   <li>{@code SampleUtils.WRSCombineFn} combines the values of each group;</li>
     *   <li>the double is dropped, leaving (group key, element).</li>
     * </ol>
     *
     * @param input table of group key to (element, numeric weight)
     * @param sampleSizes sample sizes forwarded to the sampler and the combiner (presumably
     *     indexed by group key -- confirm against SampleUtils)
     * @param seed seed forwarded to the sampler; may be {@code null}
     * @return pairs of (group key, sampled element)
     */
    public static <T, N extends Number> PCollection<Pair<Integer, T>> groupedWeightedReservoirSample(PTable<Integer, Pair<T, N>> input, int[] sampleSizes, Long seed) {
        PTypeFamily ptf = input.getTypeFamily();
        // First sub-type of the (element, weight) value type is the element type. The sub-type
        // list is not statically typed, hence the narrowly scoped unchecked cast.
        @SuppressWarnings("unchecked")
        PType<T> ttype = (PType<T>) input.getPTableType().getValueType().getSubTypes().get(0);
        PTableType<Integer, Pair<Double, T>> ptt = ptf.tableOf(ptf.ints(), ptf.pairs(ptf.doubles(), ttype));
        // NOTE(review): the SampleUtils functions are raw, so the chain below is only checked
        // by the compiler through unchecked conversions.
        return input
                .parallelDo("Initial reservoir sampling", new SampleUtils.ReservoirSampleFn(sampleSizes, seed, ttype), ptt)
                .groupByKey(1)
                .combineValues(new SampleUtils.WRSCombineFn(sampleSizes, ttype))
                .parallelDo("Extract sampled values", new MapFn<Pair<Integer, Pair<Double, T>>, Pair<Integer, T>>(){

                    @Override
                    public Pair<Integer, T> map(Pair<Integer, Pair<Double, T>> p) {
                        return Pair.of(p.first(), p.second().second());
                    }
                }, ptf.pairs(ptf.ints(), ttype));
    }
}

