/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.checkpoint.channel;

import java.io.DataOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import javax.annotation.concurrent.NotThreadSafe;
import org.apache.flink.runtime.checkpoint.channel.ChannelStateSerializer;
import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter;
import org.apache.flink.runtime.checkpoint.channel.CheckpointStartRequest;
import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo;
import org.apache.flink.runtime.checkpoint.channel.ResultSubpartitionInfo;
import org.apache.flink.runtime.io.network.buffer.Buffer;
import org.apache.flink.runtime.state.AbstractChannelStateHandle;
import org.apache.flink.runtime.state.CheckpointStreamFactory;
import org.apache.flink.runtime.state.CheckpointedStateScope;
import org.apache.flink.runtime.state.InputChannelStateHandle;
import org.apache.flink.runtime.state.ResultSubpartitionStateHandle;
import org.apache.flink.runtime.state.StreamStateHandle;
import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
import org.apache.flink.util.Preconditions;
import org.apache.flink.util.function.RunnableWithException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Writes the channel state of a single checkpoint (buffers of input channels and of result
 * subpartitions) into one checkpoint stream. Once both the input side and the output side have
 * been completed, it fulfils the futures of the {@link ChannelStateWriter.ChannelStateWriteResult}
 * with one state handle per channel that received data.
 *
 * <p>Any failure while writing or finishing fails the whole result and closes the checkpoint
 * stream. After the result is done, further writes are dropped, but their buffers are still
 * recycled.
 */
@NotThreadSafe
class ChannelStateCheckpointWriter {
    private static final Logger LOG = LoggerFactory.getLogger(ChannelStateCheckpointWriter.class);

    private final DataOutputStream dataStream;
    private final CheckpointStreamFactory.CheckpointStateOutputStream checkpointStream;
    private final ChannelStateWriter.ChannelStateWriteResult result;

    /** Positions (offset/size pairs) of every buffer written per input channel. */
    private final Map<InputChannelInfo, AbstractChannelStateHandle.StateContentMetaInfo>
            inputChannelOffsets = new HashMap<>();

    /** Positions (offset/size pairs) of every buffer written per result subpartition. */
    private final Map<ResultSubpartitionInfo, AbstractChannelStateHandle.StateContentMetaInfo>
            resultSubpartitionOffsets = new HashMap<>();

    private final ChannelStateSerializer serializer;
    private final long checkpointId;
    private boolean allInputsReceived = false;
    private boolean allOutputsReceived = false;

    /** Invoked once both inputs and outputs have been completed. */
    private final RunnableWithException onComplete;

    ChannelStateCheckpointWriter(
            CheckpointStartRequest startCheckpointItem,
            CheckpointStreamFactory streamFactory,
            ChannelStateSerializer serializer,
            RunnableWithException onComplete)
            throws Exception {
        this(
                startCheckpointItem.getCheckpointId(),
                startCheckpointItem.getTargetResult(),
                streamFactory.createCheckpointStateOutputStream(CheckpointedStateScope.EXCLUSIVE),
                serializer,
                onComplete);
    }

    ChannelStateCheckpointWriter(
            long checkpointId,
            ChannelStateWriter.ChannelStateWriteResult result,
            CheckpointStreamFactory.CheckpointStateOutputStream stream,
            ChannelStateSerializer serializer,
            RunnableWithException onComplete)
            throws Exception {
        this(checkpointId, result, serializer, onComplete, stream, new DataOutputStream(stream));
    }

    /**
     * Creates the writer and immediately writes the serializer header to {@code dataStream}.
     *
     * @throws NullPointerException if any reference argument is {@code null}
     * @throws Exception if writing the header fails; the result is failed in that case
     */
    ChannelStateCheckpointWriter(
            long checkpointId,
            ChannelStateWriter.ChannelStateWriteResult result,
            ChannelStateSerializer serializer,
            RunnableWithException onComplete,
            CheckpointStreamFactory.CheckpointStateOutputStream checkpointStateOutputStream,
            DataOutputStream dataStream)
            throws Exception {
        this.checkpointId = checkpointId;
        this.result = Preconditions.checkNotNull(result);
        this.checkpointStream = Preconditions.checkNotNull(checkpointStateOutputStream);
        this.serializer = Preconditions.checkNotNull(serializer);
        this.dataStream = Preconditions.checkNotNull(dataStream);
        this.onComplete = Preconditions.checkNotNull(onComplete);
        runWithChecks(() -> serializer.writeHeader(dataStream));
    }

    /**
     * Appends {@code buffer} as state of the given input channel. The buffer is always recycled.
     */
    void writeInput(InputChannelInfo info, Buffer buffer) throws Exception {
        write(inputChannelOffsets, info, buffer, !allInputsReceived);
    }

    /**
     * Appends {@code buffer} as state of the given result subpartition. The buffer is always
     * recycled.
     */
    void writeOutput(ResultSubpartitionInfo info, Buffer buffer) throws Exception {
        write(resultSubpartitionOffsets, info, buffer, !allOutputsReceived);
    }

    /**
     * Serializes {@code buffer} into the stream and records where it landed for {@code key}.
     *
     * <p>If the result is already done, the data is silently dropped. Otherwise a {@code false}
     * {@code precondition} (data arriving after its side was completed) fails the whole result.
     * The buffer is recycled on every path.
     */
    private <K> void write(
            Map<K, AbstractChannelStateHandle.StateContentMetaInfo> offsets,
            K key,
            Buffer buffer,
            boolean precondition)
            throws Exception {
        try {
            if (result.isDone()) {
                return;
            }
            runWithChecks(
                    () -> {
                        Preconditions.checkState(precondition);
                        long offset = checkpointStream.getPos();
                        serializer.writeData(dataStream, buffer);
                        long size = checkpointStream.getPos() - offset;
                        offsets.computeIfAbsent(
                                        key,
                                        unused -> new AbstractChannelStateHandle.StateContentMetaInfo())
                                .withDataAdded(offset, size);
                    });
        } finally {
            buffer.recycleBuffer();
        }
    }

    /** Marks the input side as complete; fails the precondition if called twice. */
    void completeInput() throws Exception {
        LOG.debug("complete input, output completed: {}", allOutputsReceived);
        complete(!allInputsReceived, () -> allInputsReceived = true);
    }

    /** Marks the output side as complete; fails the precondition if called twice. */
    void completeOutput() throws Exception {
        LOG.debug("complete output, input completed: {}", allInputsReceived);
        complete(!allOutputsReceived, () -> allOutputsReceived = true);
    }

    /**
     * Applies {@code complete}. When both sides are then complete, runs {@code onComplete}. If
     * the result is still pending, it additionally finishes the stream and fulfils the result,
     * with failures routed through {@link #runWithChecks}.
     */
    private void complete(boolean precondition, RunnableWithException complete) throws Exception {
        if (result.isDone()) {
            doComplete(precondition, complete, onComplete);
        } else {
            runWithChecks(
                    () -> doComplete(precondition, complete, onComplete, this::finishWriteAndResult));
        }
    }

    /**
     * Closes the stream and completes both result futures. If no channel received any data, the
     * stream is closed without producing a handle and both futures complete with empty lists.
     */
    private void finishWriteAndResult() throws IOException {
        if (inputChannelOffsets.isEmpty() && resultSubpartitionOffsets.isEmpty()) {
            dataStream.close();
            result.inputChannelStateHandles.complete(Collections.emptyList());
            result.resultSubpartitionStateHandles.complete(Collections.emptyList());
            return;
        }
        dataStream.flush();
        StreamStateHandle underlying = checkpointStream.closeAndGetHandle();
        completeFuture(
                underlying,
                result.inputChannelStateHandles,
                inputChannelOffsets,
                HandleFactory.INPUT_CHANNEL);
        completeFuture(
                underlying,
                result.resultSubpartitionStateHandles,
                resultSubpartitionOffsets,
                HandleFactory.RESULT_SUBPARTITION);
    }

    /**
     * Checks {@code precondition} (as an argument check), runs {@code complete}, and then runs
     * every callback in order, but only once both inputs and outputs have been received.
     */
    private void doComplete(
            boolean precondition,
            RunnableWithException complete,
            RunnableWithException... callbacks)
            throws Exception {
        Preconditions.checkArgument(precondition);
        complete.run();
        if (allInputsReceived && allOutputsReceived) {
            for (RunnableWithException callback : callbacks) {
                callback.run();
            }
        }
    }

    /** Builds one handle per recorded channel and completes {@code future} with them. */
    private <I, H extends AbstractChannelStateHandle<I>> void completeFuture(
            StreamStateHandle underlying,
            CompletableFuture<Collection<H>> future,
            Map<I, AbstractChannelStateHandle.StateContentMetaInfo> offsets,
            HandleFactory<I, H> handleFactory)
            throws IOException {
        List<H> handles = new ArrayList<>(offsets.size());
        for (Map.Entry<I, AbstractChannelStateHandle.StateContentMetaInfo> entry :
                offsets.entrySet()) {
            handles.add(createHandle(handleFactory, underlying, entry.getKey(), entry.getValue()));
        }
        future.complete(handles);
        LOG.debug(
                "channel state write completed, checkpointId: {}, handles: {}",
                checkpointId,
                handles);
    }

    /**
     * Creates the handle for one channel.
     *
     * <p>If the underlying state is available in memory, this channel's pieces are extracted via
     * {@code serializer.extractAndMerge} into a dedicated {@link ByteStreamStateHandle} (named
     * with a random UUID). Its single offset is the serializer header length. NOTE(review): this
     * presumes extractAndMerge emits a header before the data; confirm. Otherwise the handle
     * points into the shared underlying state using the recorded offsets and size.
     */
    private <I, H extends AbstractChannelStateHandle<I>> H createHandle(
            HandleFactory<I, H> handleFactory,
            StreamStateHandle underlying,
            I channelInfo,
            AbstractChannelStateHandle.StateContentMetaInfo contentMetaInfo)
            throws IOException {
        Optional<byte[]> bytes = underlying.asBytesIfInMemory();
        if (bytes.isPresent()) {
            ByteStreamStateHandle extracted =
                    new ByteStreamStateHandle(
                            UUID.randomUUID().toString(),
                            serializer.extractAndMerge(bytes.get(), contentMetaInfo.getOffsets()));
            return handleFactory.create(
                    channelInfo,
                    extracted,
                    Collections.singletonList(serializer.getHeaderLength()),
                    extracted.getStateSize());
        }
        return handleFactory.create(
                channelInfo, underlying, contentMetaInfo.getOffsets(), contentMetaInfo.getSize());
    }

    /**
     * Runs {@code action} after checking that the result is not done yet. Any exception fails the
     * result, closes the stream and is rethrown.
     *
     * <p>If that cleanup throws as well, the original exception is still the one rethrown, with
     * the cleanup failure attached as suppressed, so the root cause is never masked.
     */
    private void runWithChecks(RunnableWithException action) throws Exception {
        try {
            Preconditions.checkState(!result.isDone(), "result is already completed", result);
            action.run();
        } catch (Exception e) {
            try {
                fail(e);
            } catch (Exception cleanupFailure) {
                if (cleanupFailure != e) {
                    e.addSuppressed(cleanupFailure);
                }
            }
            throw e;
        }
    }

    /** Fails the result with {@code e} and closes the checkpoint stream. */
    public void fail(Throwable e) throws Exception {
        result.fail(e);
        checkpointStream.close();
    }

    /** Creates a typed channel state handle from a channel id, underlying state, offsets and size. */
    @FunctionalInterface
    private interface HandleFactory<I, H extends AbstractChannelStateHandle<I>> {
        HandleFactory<InputChannelInfo, InputChannelStateHandle> INPUT_CHANNEL =
                InputChannelStateHandle::new;
        HandleFactory<ResultSubpartitionInfo, ResultSubpartitionStateHandle> RESULT_SUBPARTITION =
                ResultSubpartitionStateHandle::new;

        H create(I info, StreamStateHandle underlying, List<Long> offsets, long size);
    }
}

