/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.executiongraph.failover.flip1;

import java.util.ArrayDeque;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.IdentityHashMap;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.runtime.executiongraph.failover.flip1.FailoverRegion;
import org.apache.flink.runtime.executiongraph.failover.flip1.FailoverResultPartition;
import org.apache.flink.runtime.executiongraph.failover.flip1.FailoverStrategy;
import org.apache.flink.runtime.executiongraph.failover.flip1.FailoverTopology;
import org.apache.flink.runtime.executiongraph.failover.flip1.FailoverVertex;
import org.apache.flink.runtime.executiongraph.failover.flip1.PipelinedRegionComputeUtil;
import org.apache.flink.runtime.executiongraph.failover.flip1.ResultPartitionAvailabilityChecker;
import org.apache.flink.runtime.io.network.partition.PartitionException;
import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID;
import org.apache.flink.util.ExceptionUtils;
import org.apache.flink.util.Preconditions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A {@link FailoverStrategy} that, on a task failure, restarts the failover region containing the
 * failed task together with every region reachable from it by following
 * (a) consumed result partitions that are not available back to their producer regions, and
 * (b) produced result partitions forward to their consumer regions.
 *
 * <p>Regions are computed once, at construction time, from the given {@link FailoverTopology} via
 * {@link PipelinedRegionComputeUtil#computePipelinedRegions}.
 */
public class RestartPipelinedRegionStrategy
implements FailoverStrategy {
    private static final Logger LOG = LoggerFactory.getLogger(RestartPipelinedRegionStrategy.class);

    /** The topology the failover regions are computed from. */
    private final FailoverTopology<?, ?> topology;

    /** All failover regions built from {@link #topology}, compared by identity. */
    private final Set<FailoverRegion> regions;

    /** Maps each execution vertex to the failover region it belongs to. */
    private final Map<ExecutionVertexID, FailoverRegion> vertexToRegionMap;

    /**
     * Availability checker that additionally treats partitions explicitly marked failed as
     * unavailable, delegating to the user-supplied checker otherwise.
     */
    private final RegionFailoverResultPartitionAvailabilityChecker resultPartitionAvailabilityChecker;

    /**
     * Creates the strategy with a checker that reports every result partition as available.
     *
     * @param topology the topology to compute failover regions from; must not be null
     */
    @VisibleForTesting
    public RestartPipelinedRegionStrategy(FailoverTopology<?, ?> topology) {
        this(topology, resultPartitionID -> true);
    }

    /**
     * Creates the strategy and eagerly builds the failover regions of the given topology.
     *
     * @param topology the topology to compute failover regions from; must not be null
     * @param resultPartitionAvailabilityChecker decides whether a consumed result partition is still
     *     available; must not be null
     */
    public RestartPipelinedRegionStrategy(
            FailoverTopology<?, ?> topology,
            ResultPartitionAvailabilityChecker resultPartitionAvailabilityChecker) {
        this.topology = Preconditions.checkNotNull(topology);
        this.regions = Collections.newSetFromMap(new IdentityHashMap<>());
        this.vertexToRegionMap = new HashMap<>();
        this.resultPartitionAvailabilityChecker =
                new RegionFailoverResultPartitionAvailabilityChecker(resultPartitionAvailabilityChecker);
        LOG.info("Start building failover regions.");
        this.buildFailoverRegions();
    }

    /**
     * Computes the pipelined regions of {@link #topology}, wraps each one in a {@link FailoverRegion},
     * registers it in {@link #regions}, and records the owning region of every vertex in
     * {@link #vertexToRegionMap}.
     */
    private void buildFailoverRegions() {
        final Set<? extends Set<? extends FailoverVertex<?, ?>>> distinctRegions =
                PipelinedRegionComputeUtil.computePipelinedRegions(this.topology);
        for (Set<? extends FailoverVertex<?, ?>> regionVertices : distinctRegions) {
            LOG.debug("Creating a failover region with {} vertices.", regionVertices.size());
            final FailoverRegion failoverRegion = new FailoverRegion(regionVertices);
            this.regions.add(failoverRegion);
            for (FailoverVertex<?, ?> vertex : regionVertices) {
                this.vertexToRegionMap.put(vertex.getId(), failoverRegion);
            }
        }
        LOG.info("Created {} failover regions.", this.regions.size());
    }

    /**
     * Returns the IDs of all tasks that must be restarted to recover from the failure of the given task.
     *
     * <p>If {@code cause} contains a {@link PartitionException}, the partition it reports is treated as
     * unavailable while the regions to restart are computed, so the traversal also visits that
     * partition's producer region when it is consumed by a restarted region. This temporary failed
     * state is always cleared before the method returns, including when the computation throws.
     *
     * @param executionVertexId ID of the failed task
     * @param cause the cause of the failure
     * @return IDs of all execution vertices in the regions to restart
     * @throws IllegalStateException if no failover region is known for {@code executionVertexId}
     */
    @Override
    public Set<ExecutionVertexID> getTasksNeedingRestart(ExecutionVertexID executionVertexId, Throwable cause) {
        LOG.info("Calculating tasks to restart to recover the failed task {}.", executionVertexId);
        final FailoverRegion failedRegion = this.vertexToRegionMap.get(executionVertexId);
        if (failedRegion == null) {
            throw new IllegalStateException("Can not find the failover region for task " + executionVertexId, cause);
        }

        // A data consumption error means the reported partition cannot be read: mark it failed so the
        // availability check during the traversal below reports it as unavailable.
        final Optional<IntermediateResultPartitionID> failedPartitionId =
                ExceptionUtils.findThrowable(cause, PartitionException.class)
                        .map(partitionException -> partitionException.getPartitionId().getPartitionId());
        failedPartitionId.ifPresent(this.resultPartitionAvailabilityChecker::markResultPartitionFailed);

        final Set<ExecutionVertexID> tasksToRestart = new HashSet<>();
        try {
            for (FailoverRegion region : this.getRegionsToRestart(failedRegion)) {
                tasksToRestart.addAll(region.getAllExecutionVertexIDs());
            }
        } finally {
            // The failed mark only needs to hold for this computation. Clear it even if the traversal
            // threw, otherwise the partition would stay "unavailable" for all later failovers.
            failedPartitionId.ifPresent(
                    this.resultPartitionAvailabilityChecker::removeResultPartitionFromFailedState);
        }

        LOG.info("{} tasks should be restarted to recover the failed task {}. ", tasksToRestart.size(), executionVertexId);
        return tasksToRestart;
    }

    /**
     * Computes, by breadth-first traversal starting at {@code failedRegion}, the set of regions to restart.
     *
     * <p>Each visited region pulls in (1) the producer region of every consumed result partition that
     * {@link #resultPartitionAvailabilityChecker} reports as unavailable, and (2) the region of every
     * consumer of every result partition it produces.
     *
     * @param failedRegion the region containing the failed task
     * @return all regions to restart, including {@code failedRegion}; the set compares by identity
     */
    private Set<FailoverRegion> getRegionsToRestart(FailoverRegion failedRegion) {
        final Set<FailoverRegion> regionsToRestart = Collections.newSetFromMap(new IdentityHashMap<>());
        final Set<FailoverRegion> visitedRegions = Collections.newSetFromMap(new IdentityHashMap<>());
        final ArrayDeque<FailoverRegion> regionsToVisit = new ArrayDeque<>();
        visitedRegions.add(failedRegion);
        regionsToVisit.add(failedRegion);
        while (!regionsToVisit.isEmpty()) {
            final FailoverRegion regionToRestart = regionsToVisit.poll();
            regionsToRestart.add(regionToRestart);

            // Upstream: an unavailable input partition pulls in the region that produces it.
            for (FailoverVertex<?, ?> vertex : regionToRestart.getAllExecutionVertices()) {
                for (FailoverResultPartition<?, ?> consumedPartition : vertex.getConsumedResults()) {
                    if (this.resultPartitionAvailabilityChecker.isAvailable(consumedPartition.getId())) {
                        continue;
                    }
                    final FailoverRegion producerRegion =
                            this.vertexToRegionMap.get(consumedPartition.getProducer().getId());
                    // Set.add returns false for an already-visited region, so each region is queued once.
                    if (visitedRegions.add(producerRegion)) {
                        regionsToVisit.add(producerRegion);
                    }
                }
            }

            // Downstream: the region of every consumer of a produced partition is restarted as well.
            for (FailoverVertex<?, ?> vertex : regionToRestart.getAllExecutionVertices()) {
                for (FailoverResultPartition<?, ?> producedPartition : vertex.getProducedResults()) {
                    for (FailoverVertex<?, ?> consumerVertex : producedPartition.getConsumers()) {
                        final FailoverRegion consumerRegion = this.vertexToRegionMap.get(consumerVertex.getId());
                        if (visitedRegions.add(consumerRegion)) {
                            regionsToVisit.add(consumerRegion);
                        }
                    }
                }
            }
        }
        return regionsToRestart;
    }

    /**
     * Returns the failover region containing the given vertex.
     *
     * @param vertexID ID of the execution vertex
     * @return the owning region, or {@code null} if the vertex is not part of the topology
     */
    @VisibleForTesting
    public FailoverRegion getFailoverRegion(ExecutionVertexID vertexID) {
        return this.vertexToRegionMap.get(vertexID);
    }

    /** Factory creating {@link RestartPipelinedRegionStrategy} instances. */
    public static class Factory
    implements FailoverStrategy.Factory {
        @Override
        public FailoverStrategy create(FailoverTopology<?, ?> topology, ResultPartitionAvailabilityChecker resultPartitionAvailabilityChecker) {
            return new RestartPipelinedRegionStrategy(topology, resultPartitionAvailabilityChecker);
        }
    }

    /**
     * A {@link ResultPartitionAvailabilityChecker} that reports a partition as unavailable while it is
     * explicitly marked failed, and otherwise delegates to the wrapped checker.
     */
    private static class RegionFailoverResultPartitionAvailabilityChecker
    implements ResultPartitionAvailabilityChecker {
        /** The wrapped checker consulted for partitions not marked failed. */
        private final ResultPartitionAvailabilityChecker resultPartitionAvailabilityChecker;

        /** Partitions currently marked failed. */
        private final Set<IntermediateResultPartitionID> failedPartitions;

        /**
         * @param checker the checker to delegate to; must not be null
         */
        RegionFailoverResultPartitionAvailabilityChecker(ResultPartitionAvailabilityChecker checker) {
            this.resultPartitionAvailabilityChecker = Preconditions.checkNotNull(checker);
            this.failedPartitions = new HashSet<>();
        }

        /**
         * Returns {@code false} if the partition is marked failed, otherwise the wrapped checker's verdict.
         */
        @Override
        public boolean isAvailable(IntermediateResultPartitionID resultPartitionID) {
            return !this.failedPartitions.contains(resultPartitionID)
                    && this.resultPartitionAvailabilityChecker.isAvailable(resultPartitionID);
        }

        /** Marks the given partition as failed, making {@link #isAvailable} return {@code false} for it. */
        public void markResultPartitionFailed(IntermediateResultPartitionID resultPartitionID) {
            this.failedPartitions.add(resultPartitionID);
        }

        /** Clears the failed mark of the given partition, if any. */
        public void removeResultPartitionFromFailedState(IntermediateResultPartitionID resultPartitionID) {
            this.failedPartitions.remove(resultPartitionID);
        }
    }
}

