/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.checkpoint;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import org.apache.flink.annotation.Internal;
import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.runtime.OperatorIDPair;
import org.apache.flink.runtime.checkpoint.JobManagerTaskRestore;
import org.apache.flink.runtime.checkpoint.OperatorState;
import org.apache.flink.runtime.checkpoint.OperatorStateRepartitioner;
import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
import org.apache.flink.runtime.checkpoint.RoundRobinOperatorStateRepartitioner;
import org.apache.flink.runtime.checkpoint.StateObjectCollection;
import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
import org.apache.flink.runtime.executiongraph.Execution;
import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
import org.apache.flink.runtime.jobgraph.OperatorID;
import org.apache.flink.runtime.jobgraph.OperatorInstanceID;
import org.apache.flink.runtime.state.AbstractChannelStateHandle;
import org.apache.flink.runtime.state.InputChannelStateHandle;
import org.apache.flink.runtime.state.KeyGroupRange;
import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
import org.apache.flink.runtime.state.KeyedStateHandle;
import org.apache.flink.runtime.state.OperatorStateHandle;
import org.apache.flink.runtime.state.ResultSubpartitionStateHandle;
import org.apache.flink.runtime.state.StateObject;
import org.apache.flink.util.Preconditions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@Internal
public class StateAssignmentOperation {
    /** Logger used for the diagnostic messages emitted by this class. */
    private static final Logger LOG = LoggerFactory.getLogger(StateAssignmentOperation.class);

    /** Job vertices that {@link #assignStates()} iterates over. */
    private final Set<ExecutionJobVertex> tasks;

    /** Restored operator states, keyed by operator id. */
    private final Map<OperatorID, OperatorState> operatorStates;

    /** Checkpoint id that is handed to every {@link JobManagerTaskRestore} created here. */
    private final long restoreCheckpointId;

    /** If {@code true}, state without a matching operator is logged and skipped instead of failing. */
    private final boolean allowNonRestoredState;

    /**
     * Creates a new state assignment operation.
     *
     * @param restoreCheckpointId id of the checkpoint the state is restored from
     * @param tasks job vertices to assign state to; must not be {@code null}
     * @param operatorStates restored operator states by operator id; must not be {@code null}
     * @param allowNonRestoredState whether state that maps to no operator may be skipped
     * @throws NullPointerException if {@code tasks} or {@code operatorStates} is {@code null}
     */
    public StateAssignmentOperation(
            long restoreCheckpointId,
            Set<ExecutionJobVertex> tasks,
            Map<OperatorID, OperatorState> operatorStates,
            boolean allowNonRestoredState) {
        // Null checks run in the same order as before: tasks first, then operatorStates.
        this.tasks = Preconditions.checkNotNull(tasks);
        this.operatorStates = Preconditions.checkNotNull(operatorStates);
        this.restoreCheckpointId = restoreCheckpointId;
        this.allowNonRestoredState = allowNonRestoredState;
    }

    /**
     * Matches the restored operator states against the operators of every task and, for each
     * task that has at least one operator with collected state, triggers the per-task assignment.
     *
     * @throws IllegalStateException if a restored state maps to no operator and non-restored
     *     state is not allowed
     */
    public void assignStates() {
        // Working copy: matched entries are removed from it, the field itself stays untouched.
        Map<OperatorID, OperatorState> unmatchedStates = new HashMap<>(this.operatorStates);
        checkStateMappingCompleteness(this.allowNonRestoredState, this.operatorStates, this.tasks);

        for (ExecutionJobVertex vertex : this.tasks) {
            List<OperatorIDPair> idPairs = vertex.getOperatorIDs();
            List<OperatorState> statesOfVertex = new ArrayList<>(idPairs.size());
            boolean anyOperatorHasState = false;

            for (OperatorIDPair idPair : idPairs) {
                // The user-defined id takes precedence over the generated one for the lookup.
                OperatorID lookupId = idPair.getUserDefinedOperatorID().orElse(idPair.getGeneratedOperatorID());
                OperatorState restored = unmatchedStates.remove(lookupId);
                if (restored != null) {
                    if (restored.getNumberCollectedStates() > 0) {
                        anyOperatorHasState = true;
                    }
                } else {
                    // No restored state for this operator: use an empty placeholder shaped like the vertex.
                    restored = new OperatorState(lookupId, vertex.getParallelism(), vertex.getMaxParallelism());
                }
                statesOfVertex.add(restored);
            }

            // Tasks whose operators are all stateless are skipped entirely.
            if (anyOperatorHasState) {
                this.assignAttemptState(vertex, statesOfVertex);
            }
        }
    }

    /**
     * Redistributes all state kinds of the given operator states to the current parallelism of
     * the job vertex and sets the result on the vertex's execution attempts.
     *
     * @param executionJobVertex the vertex receiving the state
     * @param operatorStates the old states, in the same order as the vertex's operator ids
     */
    private void assignAttemptState(ExecutionJobVertex executionJobVertex, List<OperatorState> operatorStates) {
        final List<OperatorIDPair> chainedOperators = executionJobVertex.getOperatorIDs();

        // Must run first: it may override the vertex's max parallelism, which is read below.
        this.checkParallelismPreconditions(operatorStates, executionJobVertex);

        final int targetParallelism = executionJobVertex.getParallelism();
        final List<KeyGroupRange> keyGroupRanges =
                createKeyGroupPartitions(executionJobVertex.getMaxParallelism(), targetParallelism);

        // Non-keyed operator state (managed and raw) is repartitioned round-robin.
        Map<OperatorInstanceID, List<OperatorStateHandle>> managedOperatorState = StateAssignmentOperation.reDistributePartitionableStates(operatorStates, targetParallelism, chainedOperators, OperatorSubtaskState::getManagedOperatorState, RoundRobinOperatorStateRepartitioner.INSTANCE);
        Map<OperatorInstanceID, List<OperatorStateHandle>> rawOperatorState = StateAssignmentOperation.reDistributePartitionableStates(operatorStates, targetParallelism, chainedOperators, OperatorSubtaskState::getRawOperatorState, RoundRobinOperatorStateRepartitioner.INSTANCE);

        // Channel state goes through the non-rescaling repartitioner.
        Map<OperatorInstanceID, List<InputChannelStateHandle>> inputChannelState = StateAssignmentOperation.reDistributePartitionableStates(operatorStates, targetParallelism, chainedOperators, OperatorSubtaskState::getInputChannelState, StateAssignmentOperation.channelStateNonRescalingRepartitioner("input channel"));
        Map<OperatorInstanceID, List<ResultSubpartitionStateHandle>> resultSubpartitionState = StateAssignmentOperation.reDistributePartitionableStates(operatorStates, targetParallelism, chainedOperators, OperatorSubtaskState::getResultSubpartitionState, StateAssignmentOperation.channelStateNonRescalingRepartitioner("result subpartition"));

        // Keyed state maps are pre-sized for one entry per (operator, subtask) pair.
        final int expectedEntries = targetParallelism * chainedOperators.size();
        Map<OperatorInstanceID, List<KeyedStateHandle>> managedKeyedState = new HashMap<>(expectedEntries);
        Map<OperatorInstanceID, List<KeyedStateHandle>> rawKeyedState = new HashMap<>(expectedEntries);
        this.reDistributeKeyedStates(operatorStates, targetParallelism, chainedOperators, keyGroupRanges, managedKeyedState, rawKeyedState);

        this.assignTaskStateToExecutionJobVertices(
                executionJobVertex,
                managedOperatorState,
                rawOperatorState,
                inputChannelState,
                resultSubpartitionState,
                managedKeyedState,
                rawKeyedState,
                targetParallelism);
    }

    /**
     * Builds one {@link TaskStateSnapshot} per subtask from the redistributed state maps and sets
     * it as initial state on the subtask's current execution attempt. Subtasks for which no
     * operator has state are left untouched.
     *
     * @param executionJobVertex the vertex whose subtasks receive state
     * @param subManagedOperatorState managed operator state per operator instance
     * @param subRawOperatorState raw operator state per operator instance
     * @param inputChannelStates input channel state per operator instance
     * @param resultSubpartitionStates result subpartition state per operator instance
     * @param subManagedKeyedState managed keyed state per operator instance
     * @param subRawKeyedState raw keyed state per operator instance
     * @param newParallelism number of subtasks to assign state to
     */
    private void assignTaskStateToExecutionJobVertices(ExecutionJobVertex executionJobVertex, Map<OperatorInstanceID, List<OperatorStateHandle>> subManagedOperatorState, Map<OperatorInstanceID, List<OperatorStateHandle>> subRawOperatorState, Map<OperatorInstanceID, List<InputChannelStateHandle>> inputChannelStates, Map<OperatorInstanceID, List<ResultSubpartitionStateHandle>> resultSubpartitionStates, Map<OperatorInstanceID, List<KeyedStateHandle>> subManagedKeyedState, Map<OperatorInstanceID, List<KeyedStateHandle>> subRawKeyedState, int newParallelism) {
        final List<OperatorIDPair> chainedOperators = executionJobVertex.getOperatorIDs();

        for (int subtask = 0; subtask < newParallelism; subtask++) {
            Execution attempt = executionJobVertex.getTaskVertices()[subtask].getCurrentExecutionAttempt();
            TaskStateSnapshot snapshot = new TaskStateSnapshot(chainedOperators.size());
            boolean subtaskHasState = false;

            for (OperatorIDPair pair : chainedOperators) {
                // Redistributed state is always keyed by the generated operator id.
                OperatorID generatedId = pair.getGeneratedOperatorID();
                OperatorSubtaskState subtaskState = StateAssignmentOperation.operatorSubtaskStateFrom(
                        OperatorInstanceID.of(subtask, generatedId),
                        subManagedOperatorState,
                        subRawOperatorState,
                        inputChannelStates,
                        resultSubpartitionStates,
                        subManagedKeyedState,
                        subRawKeyedState);
                subtaskHasState |= subtaskState.hasState();
                snapshot.putSubtaskStateByOperatorID(generatedId, subtaskState);
            }

            if (subtaskHasState) {
                attempt.setInitialState(new JobManagerTaskRestore(this.restoreCheckpointId, snapshot));
            }
        }
    }

    /**
     * Assembles the {@link OperatorSubtaskState} of one operator instance from the per-kind
     * state maps.
     *
     * @param instanceID the operator instance to look up in every map
     * @param subManagedOperatorState managed operator state per operator instance
     * @param subRawOperatorState raw operator state per operator instance
     * @param inputChannelStates input channel state per operator instance
     * @param resultSubpartitionStates result subpartition state per operator instance
     * @param subManagedKeyedState managed keyed state per operator instance
     * @param subRawKeyedState raw keyed state per operator instance
     * @return an empty state if no map contains the instance, otherwise a state built from the
     *     entries found (missing entries become empty collections)
     * @throws IllegalStateException if raw keyed state is present without managed keyed state
     */
    public static OperatorSubtaskState operatorSubtaskStateFrom(OperatorInstanceID instanceID, Map<OperatorInstanceID, List<OperatorStateHandle>> subManagedOperatorState, Map<OperatorInstanceID, List<OperatorStateHandle>> subRawOperatorState, Map<OperatorInstanceID, List<InputChannelStateHandle>> inputChannelStates, Map<OperatorInstanceID, List<ResultSubpartitionStateHandle>> resultSubpartitionStates, Map<OperatorInstanceID, List<KeyedStateHandle>> subManagedKeyedState, Map<OperatorInstanceID, List<KeyedStateHandle>> subRawKeyedState) {
        boolean knownToAnyMap = subManagedOperatorState.containsKey(instanceID)
                || subRawOperatorState.containsKey(instanceID)
                || inputChannelStates.containsKey(instanceID)
                || resultSubpartitionStates.containsKey(instanceID)
                || subManagedKeyedState.containsKey(instanceID)
                || subRawKeyedState.containsKey(instanceID);
        if (!knownToAnyMap) {
            return new OperatorSubtaskState();
        }

        // A raw keyed entry is only accepted together with a managed keyed entry.
        Preconditions.checkState(subManagedKeyedState.containsKey(instanceID) || !subRawKeyedState.containsKey(instanceID));

        List<OperatorStateHandle> managedOperator = subManagedOperatorState.getOrDefault(instanceID, Collections.emptyList());
        List<OperatorStateHandle> rawOperator = subRawOperatorState.getOrDefault(instanceID, Collections.emptyList());
        List<KeyedStateHandle> managedKeyed = subManagedKeyedState.getOrDefault(instanceID, Collections.emptyList());
        List<KeyedStateHandle> rawKeyed = subRawKeyedState.getOrDefault(instanceID, Collections.emptyList());
        List<InputChannelStateHandle> inputChannels = inputChannelStates.getOrDefault(instanceID, Collections.emptyList());
        List<ResultSubpartitionStateHandle> resultSubpartitions = resultSubpartitionStates.getOrDefault(instanceID, Collections.emptyList());

        return new OperatorSubtaskState(
                new StateObjectCollection<OperatorStateHandle>(managedOperator),
                new StateObjectCollection<OperatorStateHandle>(rawOperator),
                new StateObjectCollection<KeyedStateHandle>(managedKeyed),
                new StateObjectCollection<KeyedStateHandle>(rawKeyed),
                new StateObjectCollection<InputChannelStateHandle>(inputChannels),
                new StateObjectCollection<ResultSubpartitionStateHandle>(resultSubpartitions));
    }

    /**
     * Runs the per-operator parallelism check for every given operator state, in list order.
     *
     * @param operatorStates the restored operator states to validate
     * @param executionJobVertex the vertex the states are validated against
     * @throws IllegalStateException if any state is incompatible with the vertex
     */
    public void checkParallelismPreconditions(List<OperatorState> operatorStates, ExecutionJobVertex executionJobVertex) {
        operatorStates.forEach(
                restoredState -> StateAssignmentOperation.checkParallelismPreconditions(restoredState, executionJobVertex));
    }

    /**
     * Computes the managed and raw keyed state of every (operator, subtask) pair and stores it
     * in the two output maps under the operator's generated id.
     *
     * @param oldOperatorStates old states, positionally aligned with {@code newOperatorIDs}
     * @param newParallelism number of subtasks per operator
     * @param newOperatorIDs ids of the operators receiving state
     * @param newKeyGroupPartitions key-group range per new subtask index
     * @param newManagedKeyedState output map for managed keyed state
     * @param newRawKeyedState output map for raw keyed state
     * @throws IllegalStateException if the two operator lists differ in size
     */
    private void reDistributeKeyedStates(List<OperatorState> oldOperatorStates, int newParallelism, List<OperatorIDPair> newOperatorIDs, List<KeyGroupRange> newKeyGroupPartitions, Map<OperatorInstanceID, List<KeyedStateHandle>> newManagedKeyedState, Map<OperatorInstanceID, List<KeyedStateHandle>> newRawKeyedState) {
        Preconditions.checkState(
                newOperatorIDs.size() == oldOperatorStates.size(),
                (Object) "This method still depends on the order of the new and old operators");

        int position = 0;
        for (OperatorIDPair newOperator : newOperatorIDs) {
            // Old and new operators are matched purely by their position in the lists.
            OperatorState oldState = oldOperatorStates.get(position);
            position++;
            int previousParallelism = oldState.getParallelism();
            OperatorID generatedId = newOperator.getGeneratedOperatorID();

            for (int subtask = 0; subtask < newParallelism; subtask++) {
                Tuple2<List<KeyedStateHandle>, List<KeyedStateHandle>> managedAndRaw =
                        this.reAssignSubKeyedStates(oldState, newKeyGroupPartitions, subtask, newParallelism, previousParallelism);
                OperatorInstanceID target = OperatorInstanceID.of(subtask, generatedId);
                newManagedKeyedState.put(target, managedAndRaw.f0);
                newRawKeyedState.put(target, managedAndRaw.f1);
            }
        }
    }

    /**
     * Determines the managed and raw keyed state handles for one new subtask.
     *
     * <p>With unchanged parallelism the old subtask's handles are reused as they are. Otherwise
     * the handles intersecting the subtask's new key-group range are collected from all old
     * subtasks.
     *
     * @param operatorState the old operator state
     * @param keyGroupPartitions key-group range per new subtask index
     * @param subTaskIndex index of the new subtask
     * @param newParallelism parallelism after the restore
     * @param oldParallelism parallelism the state was taken with
     * @return a tuple of (managed keyed handles, raw keyed handles); both lists are never
     *     {@code null}
     */
    private Tuple2<List<KeyedStateHandle>, List<KeyedStateHandle>> reAssignSubKeyedStates(OperatorState operatorState, List<KeyGroupRange> keyGroupPartitions, int subTaskIndex, int newParallelism, int oldParallelism) {
        List<KeyedStateHandle> subManagedKeyedState;
        List<KeyedStateHandle> subRawKeyedState;

        if (newParallelism == oldParallelism) {
            OperatorSubtaskState subtaskState = operatorState.getState(subTaskIndex);
            if (subtaskState != null) {
                subManagedKeyedState = subtaskState.getManagedKeyedState().asList();
                subRawKeyedState = subtaskState.getRawKeyedState().asList();
            } else {
                subManagedKeyedState = Collections.emptyList();
                subRawKeyedState = Collections.emptyList();
            }
        } else {
            KeyGroupRange subtaskRange = keyGroupPartitions.get(subTaskIndex);
            subManagedKeyedState = StateAssignmentOperation.getManagedKeyedStateHandles(operatorState, subtaskRange);
            subRawKeyedState = StateAssignmentOperation.getRawKeyedStateHandles(operatorState, subtaskRange);
        }

        // Defensive: treat a null result from the lookups above as "no state" rather than
        // failing with a NullPointerException on the emptiness check below.
        if (subManagedKeyedState == null) {
            subManagedKeyedState = Collections.emptyList();
        }
        if (subRawKeyedState == null) {
            subRawKeyedState = Collections.emptyList();
        }

        if (subManagedKeyedState.isEmpty() && subRawKeyedState.isEmpty()) {
            return new Tuple2<>(Collections.<KeyedStateHandle>emptyList(), Collections.<KeyedStateHandle>emptyList());
        }
        return new Tuple2<>(subManagedKeyedState, subRawKeyedState);
    }

    /**
     * Repartitions one kind of partitionable state of all operators to the new parallelism.
     *
     * @param oldOperatorStates old states, positionally aligned with {@code newOperatorIDs}
     * @param newParallelism parallelism after the restore
     * @param newOperatorIDs ids of the operators receiving state
     * @param extractHandle selects the state kind to redistribute from a subtask state
     * @param stateRepartitioner strategy used to repartition the selected state
     * @param <T> type of the state handles
     * @return the repartitioned handles per operator instance, keyed by generated operator id
     * @throws IllegalStateException if the two operator lists differ in size
     */
    public static <T extends StateObject> Map<OperatorInstanceID, List<T>> reDistributePartitionableStates(List<OperatorState> oldOperatorStates, int newParallelism, List<OperatorIDPair> newOperatorIDs, Function<OperatorSubtaskState, StateObjectCollection<T>> extractHandle, OperatorStateRepartitioner<T> stateRepartitioner) {
        Preconditions.checkState(
                newOperatorIDs.size() == oldOperatorStates.size(),
                (Object) "This method still depends on the order of the new and old operators");

        // Indexed as [operator][old subtask] -> handles of the selected kind.
        List<List<List<T>>> handlesPerOperator = StateAssignmentOperation.splitManagedAndRawOperatorStates(oldOperatorStates, extractHandle);

        Map<OperatorInstanceID, List<T>> redistributed = new HashMap<>();
        int position = 0;
        for (OperatorIDPair newOperator : newOperatorIDs) {
            Map<OperatorInstanceID, List<T>> forOperator = StateAssignmentOperation.applyRepartitioner(
                    newOperator.getGeneratedOperatorID(),
                    stateRepartitioner,
                    handlesPerOperator.get(position),
                    oldOperatorStates.get(position).getParallelism(),
                    newParallelism);
            redistributed.putAll(forOperator);
            position++;
        }
        return redistributed;
    }

    /**
     * Extracts one kind of state handles from every subtask of every operator state.
     *
     * @param operatorStates the operator states to read
     * @param extracthandle selects the state kind from a subtask state
     * @param <T> type of the state handles
     * @return a list indexed by operator, then by old subtask index; subtasks without state
     *     contribute an empty list
     */
    private static <T extends StateObject> List<List<List<T>>> splitManagedAndRawOperatorStates(List<OperatorState> operatorStates, Function<OperatorSubtaskState, StateObjectCollection<T>> extracthandle) {
        List<List<List<T>>> handlesPerOperator = new ArrayList<>();
        for (OperatorState state : operatorStates) {
            List<List<T>> handlesPerSubtask = new ArrayList<>(state.getParallelism());
            for (int subtask = 0; subtask < state.getParallelism(); subtask++) {
                OperatorSubtaskState subtaskState = state.getState(subtask);
                if (subtaskState != null) {
                    handlesPerSubtask.add(extracthandle.apply(subtaskState).asList());
                } else {
                    handlesPerSubtask.add(Collections.<T>emptyList());
                }
            }
            handlesPerOperator.add(handlesPerSubtask);
        }
        return handlesPerOperator;
    }

    /**
     * Collects the managed keyed state handles of all old subtasks that intersect the given
     * key-group range.
     *
     * @param operatorState the old operator state to read from
     * @param subtaskKeyGroupRange key-group range of the new subtask
     * @return the intersecting handles; an empty list (never {@code null}) if no old subtask
     *     has state or nothing intersects
     */
    public static List<KeyedStateHandle> getManagedKeyedStateHandles(OperatorState operatorState, KeyGroupRange subtaskKeyGroupRange) {
        final int parallelism = operatorState.getParallelism();
        List<KeyedStateHandle> subtaskKeyedStateHandles = null;

        for (int i = 0; i < parallelism; ++i) {
            OperatorSubtaskState subtaskState = operatorState.getState(i);
            if (subtaskState == null) {
                continue;
            }
            StateObjectCollection<KeyedStateHandle> keyedStateHandles = subtaskState.getManagedKeyedState();
            if (subtaskKeyedStateHandles == null) {
                // Lazily sized from the first subtask that has state.
                subtaskKeyedStateHandles = new ArrayList<>(parallelism * keyedStateHandles.size());
            }
            StateAssignmentOperation.extractIntersectingState(keyedStateHandles, subtaskKeyGroupRange, subtaskKeyedStateHandles);
        }

        // Previously null was returned when no subtask carried state, which made callers that
        // invoke isEmpty() on the result fail with a NullPointerException.
        return subtaskKeyedStateHandles != null ? subtaskKeyedStateHandles : new ArrayList<>();
    }

    /**
     * Collects the raw keyed state handles of all old subtasks that intersect the given
     * key-group range.
     *
     * @param operatorState the old operator state to read from
     * @param subtaskKeyGroupRange key-group range of the new subtask
     * @return the intersecting handles; an empty list (never {@code null}) if no old subtask
     *     has state or nothing intersects
     */
    public static List<KeyedStateHandle> getRawKeyedStateHandles(OperatorState operatorState, KeyGroupRange subtaskKeyGroupRange) {
        final int parallelism = operatorState.getParallelism();
        List<KeyedStateHandle> extractedKeyedStateHandles = null;

        for (int i = 0; i < parallelism; ++i) {
            OperatorSubtaskState subtaskState = operatorState.getState(i);
            if (subtaskState == null) {
                continue;
            }
            StateObjectCollection<KeyedStateHandle> rawKeyedState = subtaskState.getRawKeyedState();
            if (extractedKeyedStateHandles == null) {
                // Lazily sized from the first subtask that has state.
                extractedKeyedStateHandles = new ArrayList<>(parallelism * rawKeyedState.size());
            }
            StateAssignmentOperation.extractIntersectingState(rawKeyedState, subtaskKeyGroupRange, extractedKeyedStateHandles);
        }

        // Previously null was returned when no subtask carried state, which made callers that
        // invoke isEmpty() on the result fail with a NullPointerException.
        return extractedKeyedStateHandles != null ? extractedKeyedStateHandles : new ArrayList<>();
    }

    /**
     * Intersects every non-null handle with the given key-group range and appends each non-null
     * intersection to the collector, preserving iteration order.
     *
     * @param originalSubtaskStateHandles handles of one old subtask; may contain {@code null}
     * @param rangeToExtract key-group range to intersect with
     * @param extractedStateCollector list that receives the intersections
     */
    @VisibleForTesting
    public static void extractIntersectingState(Collection<? extends KeyedStateHandle> originalSubtaskStateHandles, KeyGroupRange rangeToExtract, List<KeyedStateHandle> extractedStateCollector) {
        for (KeyedStateHandle candidate : originalSubtaskStateHandles) {
            if (candidate != null) {
                KeyedStateHandle intersection = candidate.getIntersection(rangeToExtract);
                if (intersection != null) {
                    extractedStateCollector.add(intersection);
                }
            }
        }
    }

    /**
     * Computes the key-group range of every operator index for the given parallelism.
     *
     * @param numberKeyGroups total number of key groups
     * @param parallelism number of operator indices to compute a range for
     * @return one range per operator index, in index order
     * @throws IllegalArgumentException if {@code parallelism} exceeds {@code numberKeyGroups}
     */
    public static List<KeyGroupRange> createKeyGroupPartitions(int numberKeyGroups, int parallelism) {
        Preconditions.checkArgument(parallelism <= numberKeyGroups);

        List<KeyGroupRange> partitions = new ArrayList<>(parallelism);
        int operatorIndex = 0;
        while (operatorIndex < parallelism) {
            partitions.add(KeyGroupRangeAssignment.computeKeyGroupRangeForOperatorIndex(numberKeyGroups, parallelism, operatorIndex));
            operatorIndex++;
        }
        return partitions;
    }

    /**
     * Validates that the restored state is compatible with the vertex's parallelism settings.
     *
     * <p>If the max parallelism differs and the vertex's value was not explicitly configured,
     * the vertex's max parallelism is overridden with the one of the restored state.
     *
     * @param operatorState the restored operator state
     * @param executionJobVertex the vertex the state is restored into
     * @throws IllegalStateException if the state's max parallelism is below the vertex's
     *     parallelism, or differs from an explicitly configured max parallelism
     */
    private static void checkParallelismPreconditions(OperatorState operatorState, ExecutionJobVertex executionJobVertex) {
        if (operatorState.getMaxParallelism() < executionJobVertex.getParallelism()) {
            throw new IllegalStateException("The state for task " + executionJobVertex.getJobVertexId() + " can not be restored. The maximum parallelism (" + operatorState.getMaxParallelism() + ") of the restored state is lower than the configured parallelism (" + executionJobVertex.getParallelism() + "). Please reduce the parallelism of the task to be lower or equal to the maximum parallelism.");
        }

        if (operatorState.getMaxParallelism() == executionJobVertex.getMaxParallelism()) {
            // Nothing to reconcile.
            return;
        }

        if (executionJobVertex.isMaxParallelismConfigured()) {
            throw new IllegalStateException("The maximum parallelism (" + operatorState.getMaxParallelism() + ") with which the latest checkpoint of the execution job vertex " + executionJobVertex + " has been taken and the current maximum parallelism (" + executionJobVertex.getMaxParallelism() + ") changed. This is currently not supported.");
        }

        // Max parallelism was not set explicitly: adopt the value of the restored state.
        LOG.debug(
                "Overriding maximum parallelism for JobVertex {} from {} to {}",
                executionJobVertex.getJobVertexId(),
                executionJobVertex.getMaxParallelism(),
                operatorState.getMaxParallelism());
        executionJobVertex.setMaxParallelism(operatorState.getMaxParallelism());
    }

    /**
     * Verifies that every restored operator state belongs to an operator of the given tasks,
     * matching against both generated and user-defined operator ids.
     *
     * @param allowNonRestoredState if {@code true}, unmatched state is logged and skipped
     * @param operatorStates restored operator states by operator id
     * @param tasks the job vertices whose operators are candidates for the mapping
     * @throws IllegalStateException if unmatched state exists and skipping is not allowed
     */
    private static void checkStateMappingCompleteness(boolean allowNonRestoredState, Map<OperatorID, OperatorState> operatorStates, Set<ExecutionJobVertex> tasks) {
        Set<OperatorID> knownOperatorIds = new HashSet<>();
        for (ExecutionJobVertex vertex : tasks) {
            for (OperatorIDPair idPair : vertex.getOperatorIDs()) {
                knownOperatorIds.add(idPair.getGeneratedOperatorID());
                if (idPair.getUserDefinedOperatorID().isPresent()) {
                    knownOperatorIds.add(idPair.getUserDefinedOperatorID().get());
                }
            }
        }

        for (Map.Entry<OperatorID, OperatorState> restored : operatorStates.entrySet()) {
            if (knownOperatorIds.contains(restored.getKey())) {
                continue;
            }
            OperatorState orphanedState = restored.getValue();
            if (!allowNonRestoredState) {
                throw new IllegalStateException("There is no operator for the state " + orphanedState.getOperatorID());
            }
            LOG.info("Skipped checkpoint state for operator {}.", orphanedState.getOperatorID());
        }
    }

    /**
     * Repartitions the state of one operator and keys the result by operator instance.
     *
     * @param operatorID id used to build the {@link OperatorInstanceID} keys of the result
     * @param opStateRepartitioner strategy used to repartition the state
     * @param chainOpParallelStates old state handles per old subtask index
     * @param oldParallelism parallelism the state was taken with
     * @param newParallelism parallelism after the restore
     * @param <T> type of the state handles
     * @return the repartitioned handles, one entry per subtask index returned by the
     *     repartitioner
     * @throws NullPointerException if the repartitioner produced a {@code null} list for a subtask
     */
    public static <T extends StateObject> Map<OperatorInstanceID, List<T>> applyRepartitioner(OperatorID operatorID, OperatorStateRepartitioner<T> opStateRepartitioner, List<List<T>> chainOpParallelStates, int oldParallelism, int newParallelism) {
        List<List<T>> states = StateAssignmentOperation.applyRepartitioner(opStateRepartitioner, chainOpParallelStates, oldParallelism, newParallelism);
        Map<OperatorInstanceID, List<T>> result = new HashMap<>(states.size());
        for (int subtaskIndex = 0; subtaskIndex < states.size(); ++subtaskIndex) {
            // Check the list itself. The previous code passed the (always non-null) result of a
            // comparison to checkNotNull, so the precondition could never fail.
            List<T> subtaskState = Preconditions.checkNotNull(states.get(subtaskIndex), "states.get(subtaskIndex) is null");
            result.put(OperatorInstanceID.of(subtaskIndex, operatorID), subtaskState);
        }
        return result;
    }

    /**
     * Applies the repartitioner to the given per-subtask state.
     *
     * @param opStateRepartitioner strategy used to repartition the state
     * @param chainOpParallelStates old state per old subtask index, or {@code null}
     * @param oldParallelism parallelism the state was taken with
     * @param newParallelism parallelism after the restore
     * @param <T> type of the state elements
     * @return an empty list if {@code chainOpParallelStates} is {@code null}, otherwise whatever
     *     the repartitioner returns
     */
    public static <T> List<List<T>> applyRepartitioner(OperatorStateRepartitioner<T> opStateRepartitioner, List<List<T>> chainOpParallelStates, int oldParallelism, int newParallelism) {
        return chainOpParallelStates != null
                ? opStateRepartitioner.repartitionState(chainOpParallelStates, oldParallelism, newParallelism)
                : Collections.<List<T>>emptyList();
    }

    /**
     * Creates a repartitioner for channel state that returns the old per-subtask state
     * unchanged. It rejects a change of parallelism unless the offsets of every handle are empty.
     *
     * @param logStateName name of the state kind, used in the failure message
     * @param <T> type of the channel state handles
     * @return the repartitioner; it throws {@link IllegalArgumentException} on unsupported rescaling
     */
    static <T extends AbstractChannelStateHandle<?>> OperatorStateRepartitioner<T> channelStateNonRescalingRepartitioner(String logStateName) {
        return (previousStates, oldParallelism, newParallelism) -> {
            // The handles are only inspected when the parallelism actually changed.
            boolean acceptable = oldParallelism == newParallelism
                    || previousStates.stream()
                            .flatMap(handlesOfSubtask -> handlesOfSubtask.stream().map(handle -> handle.getOffsets()))
                            .allMatch(List::isEmpty);
            String failureMessage = String.format("rescaling not supported for %s state (old: %d, new: %d)", logStateName, oldParallelism, newParallelism);
            Preconditions.checkArgument(acceptable, (Object) failureMessage);
            return previousStates;
        };
    }
}

