/*
 * Decompiled with CFR 0.152.
 */
package org.cicirello.math.rand;

import java.util.concurrent.ThreadLocalRandom;
import java.util.random.RandomGenerator;
import org.cicirello.math.rand.RandomIndexer;
import org.cicirello.math.rand.RandomVariates;
import org.cicirello.util.ArrayFiller;
import org.cicirello.util.ArrayMinimumLengthEnforcer;

/**
 * Static utilities for drawing random samples of {@code k} integers, without replacement, from
 * the range {@code 0, 1, ..., n-1}.
 *
 * <p>Three algorithms are provided: reservoir sampling, pool sampling, and insertion sampling.
 * The {@code sample} methods choose among them based on the relative sizes of {@code n} and
 * {@code k}. Overloads without a {@link RandomGenerator} parameter use
 * {@link ThreadLocalRandom#current()}.
 *
 * <p>NOTE(review): the algorithms below assume the following about the project helpers. Confirm
 * these against those classes:
 * <ul>
 *   <li>{@code RandomIndexer.nextInt(bound, gen)} returns a uniform value in {@code [0, bound)}.
 *   <li>{@code ArrayFiller.create(n)} returns {@code {0, ..., n-1}}.
 *   <li>{@code ArrayFiller.fillPartial(a, k)} sets {@code a[i] = i} for {@code i < k}.
 *   <li>{@code ArrayMinimumLengthEnforcer.enforce(a, k)} returns an array of length at least
 *       {@code k}, accepting {@code null}.
 * </ul>
 */
public final class RandomSampler {
    /** Not instantiable: this class only contains static methods. */
    private RandomSampler() {
    }

    /**
     * Samples {@code k} integers from {@code 0..n-1} with reservoir sampling, using
     * {@link ThreadLocalRandom#current()} as the source of randomness.
     *
     * @param n the size of the range sampled from
     * @param k the number of integers to sample
     * @param result an array to hold the sample; may be {@code null} or too short, in which case
     *     {@code ArrayMinimumLengthEnforcer} supplies the array
     * @return the array holding the sample in its first {@code k} positions
     * @throws IllegalArgumentException if {@code k > n}
     */
    public static int[] sampleReservoir(int n, int k, int[] result) {
        return sampleReservoir(n, k, result, ThreadLocalRandom.current());
    }

    /**
     * Samples {@code k} integers from {@code 0..n-1} with reservoir sampling. Runs one random
     * draw per element beyond the first {@code k}, so it is O(n).
     *
     * @param n the size of the range sampled from
     * @param k the number of integers to sample
     * @param result an array to hold the sample; may be {@code null} or too short, in which case
     *     {@code ArrayMinimumLengthEnforcer} supplies the array
     * @param gen the source of randomness
     * @return the array holding the sample in its first {@code k} positions
     * @throws IllegalArgumentException if {@code k > n}
     */
    public static int[] sampleReservoir(int n, int k, int[] result, RandomGenerator gen) {
        if (k > n) {
            throw new IllegalArgumentException("k must be no greater than n");
        }
        result = ArrayMinimumLengthEnforcer.enforce(result, k);
        // Seed the reservoir with the first k candidates.
        ArrayFiller.fillPartial(result, k);
        for (int i = k; i < n; ++i) {
            // Candidate i replaces reservoir slot j only when the draw lands inside the reservoir.
            int j = RandomIndexer.nextInt(i + 1, gen);
            if (j < k) {
                result[j] = i;
            }
        }
        return result;
    }

    /**
     * Samples {@code k} integers from {@code 0..n-1} with pool sampling, using
     * {@link ThreadLocalRandom#current()} as the source of randomness.
     *
     * @param n the size of the range sampled from
     * @param k the number of integers to sample
     * @param result an array to hold the sample; may be {@code null} or too short, in which case
     *     {@code ArrayMinimumLengthEnforcer} supplies the array
     * @return the array holding the sample in its first {@code k} positions
     * @throws IllegalArgumentException if {@code k > n}
     */
    public static int[] samplePool(int n, int k, int[] result) {
        return samplePool(n, k, result, ThreadLocalRandom.current());
    }

    /**
     * Samples {@code k} integers from {@code 0..n-1} with pool sampling. Allocates a temporary
     * pool of length {@code n} and performs {@code k} random draws.
     *
     * @param n the size of the range sampled from
     * @param k the number of integers to sample
     * @param result an array to hold the sample; may be {@code null} or too short, in which case
     *     {@code ArrayMinimumLengthEnforcer} supplies the array
     * @param gen the source of randomness
     * @return the array holding the sample in its first {@code k} positions
     * @throws IllegalArgumentException if {@code k > n}
     */
    public static int[] samplePool(int n, int k, int[] result, RandomGenerator gen) {
        if (k > n) {
            throw new IllegalArgumentException("k must be no greater than n");
        }
        result = ArrayMinimumLengthEnforcer.enforce(result, k);
        int[] pool = ArrayFiller.create(n);
        // pool[0 .. remaining-1] holds the values not yet drawn.
        int remaining = n;
        for (int i = 0; i < k; ++i) {
            int temp = RandomIndexer.nextInt(remaining, gen);
            result[i] = pool[temp];
            // Overwrite the drawn slot with the last live value and shrink the live prefix.
            pool[temp] = pool[--remaining];
        }
        return result;
    }

    /**
     * Samples {@code k} integers from {@code 0..n-1} with insertion sampling, using
     * {@link ThreadLocalRandom#current()} as the source of randomness.
     *
     * @param n the size of the range sampled from
     * @param k the number of integers to sample
     * @param result an array to hold the sample; may be {@code null} or too short, in which case
     *     {@code ArrayMinimumLengthEnforcer} supplies the array
     * @return the array holding the sample, in increasing order, in its first {@code k} positions
     * @throws IllegalArgumentException if {@code k > n}
     */
    public static int[] sampleInsertion(int n, int k, int[] result) {
        return sampleInsertion(n, k, result, ThreadLocalRandom.current());
    }

    /**
     * Samples {@code k} integers from {@code 0..n-1} with insertion sampling. Uses no extra
     * memory, but shifting elements costs O(k^2) in the worst case, so it suits small {@code k}.
     *
     * @param n the size of the range sampled from
     * @param k the number of integers to sample
     * @param result an array to hold the sample; may be {@code null} or too short, in which case
     *     {@code ArrayMinimumLengthEnforcer} supplies the array
     * @param gen the source of randomness
     * @return the array holding the sample, in increasing order, in its first {@code k} positions
     * @throws IllegalArgumentException if {@code k > n}
     */
    public static int[] sampleInsertion(int n, int k, int[] result, RandomGenerator gen) {
        if (k > n) {
            throw new IllegalArgumentException("k must be no greater than n");
        }
        result = ArrayMinimumLengthEnforcer.enforce(result, k);
        // Invariant: before iteration i, result[k-i .. k-1] holds the i values chosen so far,
        // in strictly increasing order.
        for (int i = 0; i < k; ++i) {
            // Rank of the next value among the n - i values not yet chosen.
            int temp = RandomIndexer.nextInt(n - i, gen);
            // j is declared outside the loop because its final value picks the insertion slot.
            int j = k - i;
            // For each chosen value <= temp, bump temp past it and shift that value one slot
            // left, opening room for temp in sorted position.
            while (j < k && temp >= result[j]) {
                result[j - 1] = result[j];
                ++temp;
                ++j;
            }
            result[j - 1] = temp;
        }
        return result;
    }

    /**
     * Samples each integer in {@code 0..n-1} independently with probability {@code p}, using
     * {@link ThreadLocalRandom#current()} as the source of randomness.
     *
     * @param n the size of the range sampled from
     * @param p the per-element inclusion probability
     * @return the sample; empty if {@code p <= 0}, all of {@code 0..n-1} if {@code p >= 1}
     */
    public static int[] sample(int n, double p) {
        return sample(n, p, ThreadLocalRandom.current());
    }

    /**
     * Samples each integer in {@code 0..n-1} independently with probability {@code p}.
     *
     * <p>The sample size is drawn from a binomial distribution via
     * {@code RandomVariates.nextBinomial(n, p, r)}. That many integers are then sampled without
     * replacement.
     *
     * @param n the size of the range sampled from
     * @param p the per-element inclusion probability
     * @param r the source of randomness
     * @return the sample; empty if {@code p <= 0}, all of {@code 0..n-1} if {@code p >= 1}
     */
    public static int[] sample(int n, double p, RandomGenerator r) {
        if (p <= 0.0) {
            return new int[0];
        }
        if (p >= 1.0) {
            return ArrayFiller.create(n);
        }
        // Passing null lets the chosen sampler allocate an array sized for the drawn count.
        return sample(n, RandomVariates.nextBinomial(n, p, r), null, r);
    }

    /**
     * Samples {@code k} integers from {@code 0..n-1}, choosing the algorithm by the sizes of
     * {@code n} and {@code k}. Uses {@link ThreadLocalRandom#current()} as the source of
     * randomness.
     *
     * @param n the size of the range sampled from
     * @param k the number of integers to sample
     * @param result an array to hold the sample; may be {@code null} or too short
     * @return the array holding the sample in its first {@code k} positions
     * @throws IllegalArgumentException if {@code k > n}
     */
    public static int[] sample(int n, int k, int[] result) {
        return sample(n, k, result, ThreadLocalRandom.current());
    }

    /**
     * Samples {@code k} integers from {@code 0..n-1}, choosing the algorithm by size:
     * <ul>
     *   <li>insertion sampling when {@code 2k < n} and {@code k*k < n};
     *   <li>pool sampling when {@code 2k < n} otherwise;
     *   <li>reservoir sampling in all other cases.
     * </ul>
     *
     * @param n the size of the range sampled from
     * @param k the number of integers to sample
     * @param result an array to hold the sample; may be {@code null} or too short
     * @param gen the source of randomness
     * @return the array holding the sample in its first {@code k} positions
     * @throws IllegalArgumentException if {@code k > n}
     */
    public static int[] sample(int n, int k, int[] result, RandomGenerator gen) {
        // Compare in long arithmetic. In int, k + k and k * k can overflow (e.g. k = 46341
        // makes k * k negative), which would route a large k to the O(k^2) insertion sampler.
        if (2L * k < n) {
            if ((long) k * k < n) {
                return sampleInsertion(n, k, result, gen);
            }
            return samplePool(n, k, result, gen);
        }
        return sampleReservoir(n, k, result, gen);
    }
}

