/*
 * Decompiled with CFR 0.152.
 */
package org.apache.crunch.lib.join;

import java.io.Serializable;
import java.util.Random;
import org.apache.crunch.DoFn;
import org.apache.crunch.Emitter;
import org.apache.crunch.MapFn;
import org.apache.crunch.PTable;
import org.apache.crunch.Pair;
import org.apache.crunch.lib.join.DefaultJoinStrategy;
import org.apache.crunch.lib.join.JoinStrategy;
import org.apache.crunch.lib.join.JoinType;
import org.apache.crunch.types.PTableType;
import org.apache.crunch.types.PTypeFamily;

/**
 * A {@link JoinStrategy} that splits each join key into several shards and delegates the actual
 * join to a {@link DefaultJoinStrategy} over composite {@code (key, shardIndex)} keys.
 *
 * <p>Every left-side record is replicated into all shards {@code 0 .. numShards-1} of its key.
 * Every right-side record is assigned to exactly one of those shards, chosen pseudo-randomly.
 * Each right-side record therefore meets every left-side record with the same key exactly once.
 * The values for a single key are spread over up to {@code numShards} groups in the wrapped join.
 * The shard index is stripped from the keys of the result.
 *
 * <p>Because left-side records are replicated, the wrapped join sees each shard independently. An
 * outer join on the left side would emit an unmatched left record once for every shard that
 * received no right-side record for that key. For that reason {@link JoinType#LEFT_OUTER_JOIN}
 * and {@link JoinType#FULL_OUTER_JOIN} are rejected.
 *
 * @param <K> join key type
 * @param <U> left-side value type
 * @param <V> right-side value type
 */
public class ShardedJoinStrategy<K, U, V>
implements JoinStrategy<K, U, V> {
    /** Performs the join on the sharded {@code (key, shardIndex)} keys. */
    private final JoinStrategy<Pair<K, Integer>, U, V> wrappedJoinStrategy = new DefaultJoinStrategy<Pair<K, Integer>, U, V>();
    /** Decides how many shards each key is split into. */
    private final ShardingStrategy<K> shardingStrategy;

    /**
     * Creates a strategy that splits every key into the same number of shards.
     *
     * @param numShards number of shards per key; must be at least 1
     * @throws IllegalArgumentException if {@code numShards < 1}
     */
    public ShardedJoinStrategy(int numShards) {
        this(new ConstantShardingStrategy<K>(numShards));
    }

    /**
     * Creates a strategy whose per-key shard count is decided by {@code shardingStrategy}.
     *
     * @param shardingStrategy supplies the shard count for each key; must not be null
     * @throws NullPointerException if {@code shardingStrategy} is null
     */
    public ShardedJoinStrategy(ShardingStrategy<K> shardingStrategy) {
        if (shardingStrategy == null) {
            // Fail at construction rather than later, when the pipeline runs.
            throw new NullPointerException("shardingStrategy");
        }
        this.shardingStrategy = shardingStrategy;
    }

    /**
     * Joins {@code left} and {@code right} in three steps: shard both sides, join on the sharded
     * keys with the wrapped {@link DefaultJoinStrategy}, then drop the shard index from the keys.
     *
     * @param left the left-side table; each record is replicated into every shard of its key
     * @param right the right-side table; each record is sent to a single shard of its key
     * @param joinType the join type; must not be a left or full outer join
     * @return the joined table, keyed by the original join key
     * @throws UnsupportedOperationException if {@code joinType} is
     *     {@link JoinType#LEFT_OUTER_JOIN} or {@link JoinType#FULL_OUTER_JOIN}
     */
    @Override
    public PTable<K, Pair<U, V>> join(PTable<K, U> left, PTable<K, V> right, JoinType joinType) {
        if (joinType == JoinType.FULL_OUTER_JOIN || joinType == JoinType.LEFT_OUTER_JOIN) {
            throw new UnsupportedOperationException("Join type " + joinType + " not supported by ShardedJoinStrategy");
        }
        PTypeFamily ptf = left.getTypeFamily();
        PTableType<Pair<K, Integer>, U> shardedLeftType =
            ptf.tableOf(ptf.pairs(left.getKeyType(), ptf.ints()), left.getValueType());
        PTableType<Pair<K, Integer>, V> shardedRightType =
            ptf.tableOf(ptf.pairs(right.getKeyType(), ptf.ints()), right.getValueType());
        PTableType<K, Pair<U, V>> outputType =
            ptf.tableOf(left.getKeyType(), ptf.pairs(left.getValueType(), right.getValueType()));

        PTable<Pair<K, Integer>, U> shardedLeft = left.parallelDo(
            "Pre-shard left", new PreShardLeftSideFn<K, U>(this.shardingStrategy), shardedLeftType);
        PTable<Pair<K, Integer>, V> shardedRight = right.parallelDo(
            "Pre-shard right", new PreShardRightSideFn<K, V>(this.shardingStrategy), shardedRightType);

        // The wrapped join yields (key, shard) -> (leftValue, rightValue).
        PTable<Pair<K, Integer>, Pair<U, V>> shardedJoined =
            this.wrappedJoinStrategy.join(shardedLeft, shardedRight, joinType);
        return shardedJoined.parallelDo("Unshard", new UnshardFn<K, U, V>(), outputType);
    }

    /**
     * Returns the shard count the strategy reports for {@code key}.
     *
     * @throws IllegalArgumentException if the strategy reports fewer than one shard
     */
    private static <K> int checkedNumShards(ShardingStrategy<K> shardingStrategy, K key) {
        int numShards = shardingStrategy.getNumShards(key);
        if (numShards < 1) {
            throw new IllegalArgumentException("Num shards must be > 0, got " + numShards + " for " + key);
        }
        return numShards;
    }

    /** A {@link ShardingStrategy} that returns the same shard count for every key. */
    private static class ConstantShardingStrategy<K>
    implements ShardingStrategy<K> {
        private final int numShards;

        /**
         * @param numShards the shard count for every key; must be at least 1
         * @throws IllegalArgumentException if {@code numShards < 1}
         */
        public ConstantShardingStrategy(int numShards) {
            if (numShards < 1) {
                // Reject invalid counts up front instead of once per record at pipeline run time.
                throw new IllegalArgumentException("Num shards must be > 0, got " + numShards);
            }
            this.numShards = numShards;
        }

        @Override
        public int getNumShards(K key) {
            return this.numShards;
        }
    }

    /** Removes the shard index from a joined record's key: {@code ((k, shard), v) -> (k, v)}. */
    private static class UnshardFn<K, U, V>
    extends MapFn<Pair<Pair<K, Integer>, Pair<U, V>>, Pair<K, Pair<U, V>>> {
        @Override
        public Pair<K, Pair<U, V>> map(Pair<Pair<K, Integer>, Pair<U, V>> input) {
            return Pair.of(input.first().first(), input.second());
        }
    }

    /** Assigns each right-side record to one pseudo-randomly chosen shard of its key. */
    private static class PreShardRightSideFn<K, V>
    extends MapFn<Pair<K, V>, Pair<Pair<K, Integer>, V>> {
        private final ShardingStrategy<K> shardingStrategy;
        // Created in initialize(); transient, so it is not serialized with the function.
        private transient Random random;

        public PreShardRightSideFn(ShardingStrategy<K> shardingStrategy) {
            this.shardingStrategy = shardingStrategy;
        }

        @Override
        public void initialize() {
            // The seed depends only on the task ID, not on the attempt, so every attempt of the
            // same task starts from the same random sequence.
            this.random = new Random(this.getTaskAttemptID().getTaskID().getId());
        }

        @Override
        public Pair<Pair<K, Integer>, V> map(Pair<K, V> input) {
            K key = input.first();
            int numShards = checkedNumShards(this.shardingStrategy, key);
            return Pair.of(Pair.of(key, this.random.nextInt(numShards)), input.second());
        }
    }

    /** Replicates each left-side record into every shard {@code 0 .. numShards-1} of its key. */
    private static class PreShardLeftSideFn<K, U>
    extends DoFn<Pair<K, U>, Pair<Pair<K, Integer>, U>> {
        private final ShardingStrategy<K> shardingStrategy;

        public PreShardLeftSideFn(ShardingStrategy<K> shardingStrategy) {
            this.shardingStrategy = shardingStrategy;
        }

        @Override
        public void process(Pair<K, U> input, Emitter<Pair<Pair<K, Integer>, U>> emitter) {
            K key = input.first();
            U value = input.second();
            int numShards = checkedNumShards(this.shardingStrategy, key);
            for (int shard = 0; shard < numShards; ++shard) {
                emitter.emit(Pair.of(Pair.of(key, shard), value));
            }
        }
    }

    /**
     * Decides how many shards the records for a given key are split into.
     *
     * <p>Instances are held as fields by the sharding functions, which is presumably why this
     * interface extends {@link Serializable}.
     */
    public interface ShardingStrategy<K>
    extends Serializable {
        /**
         * @param key a join key
         * @return the number of shards for {@code key}; must be at least 1
         */
        int getNumShards(K key);
    }
}

