/*
 * Decompiled with CFR 0.152.
 */
package org.pipecraft.pipes.sync.inter.join;

import java.io.File;
import java.io.IOException;
import java.io.InterruptedIOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.pipecraft.infra.concurrent.FailableFunction;
import org.pipecraft.infra.io.FileUtils;
import org.pipecraft.infra.io.FileWriteOptions;
import org.pipecraft.pipes.exceptions.IOPipeException;
import org.pipecraft.pipes.exceptions.PipeException;
import org.pipecraft.pipes.serialization.CodecFactory;
import org.pipecraft.pipes.sync.Pipe;
import org.pipecraft.pipes.sync.inter.CompoundPipe;
import org.pipecraft.pipes.sync.inter.ConcatPipe;
import org.pipecraft.pipes.sync.inter.FilterPipe;
import org.pipecraft.pipes.sync.inter.join.JoinMode;
import org.pipecraft.pipes.sync.inter.join.JoinRecord;
import org.pipecraft.pipes.sync.source.BinInputReaderPipe;
import org.pipecraft.pipes.sync.source.CollectionReaderPipe;
import org.pipecraft.pipes.sync.source.EmptyPipe;
import org.pipecraft.pipes.terminal.SharderByHashPipe;
import org.pipecraft.pipes.utils.PipeSupplier;

/**
 * A pipe joining the items of a "left" pipe with the items of one or more "right" pipes by key.
 * <p>
 * When the pipeline is created, every input pipe is first partitioned into {@code partitionCount}
 * temporary files (via {@link SharderByHashPipe}, keyed by the side's key extractor) inside a
 * temporary folder created under {@code tmpFolder}. Each partition is then joined in memory:
 * <ol>
 *   <li>the partition's left items are grouped by key into {@link JoinRecord}s;</li>
 *   <li>the partition's right items are attached to the record with the same key. A right item
 *       whose key has no left match gets a new record only in {@link JoinMode#OUTER} mode, and is
 *       dropped otherwise;</li>
 *   <li>records are emitted only if {@link JoinMode#shouldOutput} accepts them.</li>
 * </ol>
 * Per-partition results are exposed as {@link PipeSupplier}s concatenated by a {@link ConcatPipe}.
 * A partition's in-memory map is only built when its supplier is invoked. The temporary partition
 * folder is deleted on {@link #close()}.
 *
 * @param <K> the join key type
 * @param <L> the left item type
 * @param <R> the right item type
 */
public class HashJoinPipe<K, L, R>
extends CompoundPipe<JoinRecord<K, L, R>> {
    private final Pipe<L> leftPipe;
    private final FailableFunction<L, K, PipeException> leftKeyExtractor;
    private final List<? extends Pipe<R>> rightPipes;
    private final FailableFunction<R, K, PipeException> rightKeyExtractor;
    private final JoinMode joinMode;
    private final CodecFactory<L> leftCodec;
    private final CodecFactory<R> rightCodec;
    private final File tmpFolder;
    private final int partitionCount;
    // Created by createPipeline(); remains null if the pipeline was never created.
    private File partitionsFolder;

    /**
     * Creates a join of one left pipe with several right pipes.
     *
     * @param leftPipe the left input
     * @param leftKeyExtractor extracts the join key from a left item
     * @param rightPipes the right inputs; the index of each pipe in this list is the index passed to
     *        {@link JoinRecord#addRight}
     * @param rightKeyExtractor extracts the join key from a right item
     * @param joinMode controls which records are output, and whether right items without a left
     *        match are kept (only in {@link JoinMode#OUTER})
     * @param partitionCount number of partitions each input is split into
     * @param leftCodec codec used to write and read back the left partition files
     * @param rightCodec codec used to write and read back the right partition files
     * @param tmpFolder folder under which the temporary partitions folder is created
     */
    public HashJoinPipe(Pipe<L> leftPipe, FailableFunction<L, K, PipeException> leftKeyExtractor, List<? extends Pipe<R>> rightPipes, FailableFunction<R, K, PipeException> rightKeyExtractor, JoinMode joinMode, int partitionCount, CodecFactory<L> leftCodec, CodecFactory<R> rightCodec, File tmpFolder) {
        this.leftPipe = leftPipe;
        this.leftKeyExtractor = leftKeyExtractor;
        this.rightPipes = rightPipes;
        this.rightKeyExtractor = rightKeyExtractor;
        this.joinMode = joinMode;
        this.partitionCount = partitionCount;
        this.leftCodec = leftCodec;
        this.rightCodec = rightCodec;
        this.tmpFolder = tmpFolder;
    }

    /**
     * Creates a join of one left pipe with a single right pipe (right index 0).
     * See the list-based constructor for parameter semantics.
     */
    public HashJoinPipe(Pipe<L> leftPipe, FailableFunction<L, K, PipeException> leftKeyExtractor, Pipe<R> rightPipe, FailableFunction<R, K, PipeException> rightKeyExtractor, JoinMode joinMode, int partitionCount, CodecFactory<L> leftCodec, CodecFactory<R> rightCodec, File tmpFolder) {
        this(leftPipe, leftKeyExtractor, Collections.singletonList(rightPipe), rightKeyExtractor, joinMode, partitionCount, leftCodec, rightCodec, tmpFolder);
    }

    /**
     * Creates a pipe with no left side that groups the items of the given right pipes by key.
     * It uses an empty left pipe and {@link JoinMode#OUTER}, so that every right key produces a record.
     *
     * @param rightPipes the inputs to group
     * @param rightKeyExtractor extracts the grouping key from an item
     * @param partitionCount number of partitions each input is split into
     * @param rightCodec codec used to write and read back the partition files
     * @param tmpFolder folder under which the temporary partitions folder is created
     */
    public HashJoinPipe(List<? extends Pipe<R>> rightPipes, FailableFunction<R, K, PipeException> rightKeyExtractor, int partitionCount, CodecFactory<R> rightCodec, File tmpFolder) {
        this(EmptyPipe.instance(), v -> null, rightPipes, rightKeyExtractor, JoinMode.OUTER, partitionCount, null, rightCodec, tmpFolder);
    }

    /**
     * Partitions all inputs into temporary files, then returns a pipe concatenating the
     * per-partition join results.
     *
     * @throws IOPipeException wrapping any {@link IOException} raised while partitioning
     */
    @Override
    protected Pipe<JoinRecord<K, L, R>> createPipeline() throws PipeException, InterruptedException {
        try {
            partitionsFolder = FileUtils.createTempFolder("hashjoin_shards", tmpFolder);
            partition(leftPipe, partitionsFolder, "L", leftKeyExtractor, leftCodec);
            for (int i = 0; i < rightPipes.size(); i++) {
                Pipe<R> rightPipe = rightPipes.get(i);
                partition(rightPipe, partitionsFolder, "R" + i, rightKeyExtractor, rightCodec);
            }
            // One supplier per partition, so presize by partitionCount.
            List<PipeSupplier<JoinRecord<K, L, R>>> suppliers = new ArrayList<>(partitionCount);
            for (int partitionInd = 0; partitionInd < partitionCount; partitionInd++) {
                suppliers.add(getPartitionResultSupplier(partitionInd, partitionsFolder));
            }
            return new ConcatPipe<JoinRecord<K, L, R>>(suppliers);
        } catch (IOException e) {
            throw new IOPipeException(e);
        }
    }

    /**
     * Closes this pipe and deletes the temporary partitions folder, if one was created.
     * The folder is deleted even if closing the underlying pipeline fails.
     */
    @Override
    public void close() throws IOException {
        try {
            super.close();
        } finally {
            // Null when close() is called before createPipeline() ran (or before it created the folder).
            if (partitionsFolder != null) {
                FileUtils.deleteFiles(partitionsFolder);
            }
        }
    }

    /**
     * Writes every item of {@code pipe} into one of {@code partitionCount} files in
     * {@code outputFolder}, named {@code <filenamePrefix>_<partitionIndex>}. The target file is
     * chosen by {@link SharderByHashPipe} using {@code keyExtractor}.
     */
    private <Q> void partition(Pipe<Q> pipe, File outputFolder, String filenamePrefix, FailableFunction<Q, K, PipeException> keyExtractor, CodecFactory<Q> codec) throws IOException, PipeException, InterruptedException {
        String[] fileNames = new String[partitionCount];
        for (int i = 0; i < fileNames.length; i++) {
            fileNames[i] = getPartitionFilename(filenamePrefix, i);
        }
        try (SharderByHashPipe<Q> sharder = new SharderByHashPipe<Q>(pipe, codec, keyExtractor, ind -> fileNames[ind], partitionCount, outputFolder, new FileWriteOptions())) {
            sharder.start();
        }
    }

    /** Returns the name of partition file {@code index} for the given side prefix ("L", "R0", "R1", ...). */
    private static String getPartitionFilename(String prefix, int index) {
        return prefix + "_" + index;
    }

    /**
     * Returns a supplier that, when invoked, joins partition {@code partitionInd} in memory and
     * returns a pipe over the resulting records that pass {@link JoinMode#shouldOutput}.
     * Missing partition files (e.g. a side that produced nothing for this partition) are skipped.
     */
    private PipeSupplier<JoinRecord<K, L, R>> getPartitionResultSupplier(int partitionInd, File partitionsFolder) {
        return () -> {
            Map<K, JoinRecord<K, L, R>> resultMap = new HashMap<>();

            // 1. Group the partition's left items by key.
            File leftFile = new File(partitionsFolder, getPartitionFilename("L", partitionInd));
            if (leftFile.exists()) {
                try (BinInputReaderPipe<L> leftReader = new BinInputReaderPipe<L>(leftFile, leftCodec)) {
                    leftReader.start();
                    L next;
                    while ((next = leftReader.next()) != null) {
                        K key = leftKeyExtractor.apply(next);
                        resultMap.computeIfAbsent(key, k -> new JoinRecord<>(k)).addLeft(next);
                    }
                } catch (InterruptedException e) {
                    throw toInterruptedIOException(e);
                }
            }

            // 2. Attach each right pipe's items for this partition to the matching records.
            for (int rightPipeInd = 0; rightPipeInd < rightPipes.size(); rightPipeInd++) {
                File rightFile = new File(partitionsFolder, getPartitionFilename("R" + rightPipeInd, partitionInd));
                if (!rightFile.exists()) {
                    continue;
                }
                try (BinInputReaderPipe<R> rightReader = new BinInputReaderPipe<R>(rightFile, rightCodec)) {
                    rightReader.start();
                    R next;
                    while ((next = rightReader.next()) != null) {
                        K key = rightKeyExtractor.apply(next);
                        JoinRecord<K, L, R> jr = resultMap.get(key);
                        if (jr == null) {
                            // Keys with no left match are only retained in OUTER mode.
                            if (joinMode != JoinMode.OUTER) {
                                continue;
                            }
                            jr = new JoinRecord<>(key);
                            resultMap.put(key, jr);
                        }
                        jr.addRight(rightPipeInd, next);
                    }
                } catch (InterruptedException e) {
                    throw toInterruptedIOException(e);
                }
            }

            // 3. Emit only the records the join mode accepts.
            return new FilterPipe<JoinRecord<K, L, R>>(new CollectionReaderPipe<>(resultMap.values()),
                jr -> joinMode.shouldOutput(jr, rightPipes.size()));
        };
    }

    /**
     * Converts an {@link InterruptedException} into an {@link InterruptedIOException}.
     * It restores the thread's interrupt status, which catching the exception cleared, so callers
     * up the stack can still observe the interruption. It also keeps the original exception as
     * the cause.
     */
    private static InterruptedIOException toInterruptedIOException(InterruptedException e) {
        Thread.currentThread().interrupt();
        InterruptedIOException ioe = new InterruptedIOException();
        ioe.initCause(e);
        return ioe;
    }
}

