/*
 * Decompiled with CFR 0.152.
 */
package org.apache.beam.repackaged.beam_runners_direct_java.runners.core.construction.graph;

import com.google.auto.value.AutoValue;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.repackaged.beam_runners_core_construction_java.com.google.common.base.Preconditions;
import org.apache.beam.repackaged.beam_runners_core_construction_java.com.google.common.collect.HashMultimap;
import org.apache.beam.repackaged.beam_runners_core_construction_java.com.google.common.collect.Multimap;
import org.apache.beam.repackaged.beam_runners_direct_java.runners.core.construction.PTransformTranslation;
import org.apache.beam.repackaged.beam_runners_direct_java.runners.core.construction.SyntheticComponents;
import org.apache.beam.repackaged.beam_runners_direct_java.runners.core.construction.graph.AutoValue_OutputDeduplicator_DeduplicationResult;
import org.apache.beam.repackaged.beam_runners_direct_java.runners.core.construction.graph.AutoValue_OutputDeduplicator_PTransformDeduplication;
import org.apache.beam.repackaged.beam_runners_direct_java.runners.core.construction.graph.AutoValue_OutputDeduplicator_StageDeduplication;
import org.apache.beam.repackaged.beam_runners_direct_java.runners.core.construction.graph.AutoValue_OutputDeduplicator_StageOrTransform;
import org.apache.beam.repackaged.beam_runners_direct_java.runners.core.construction.graph.ExecutableStage;
import org.apache.beam.repackaged.beam_runners_direct_java.runners.core.construction.graph.ImmutableExecutableStage;
import org.apache.beam.repackaged.beam_runners_direct_java.runners.core.construction.graph.PipelineNode;
import org.apache.beam.repackaged.beam_runners_direct_java.runners.core.construction.graph.QueryablePipeline;

class OutputDeduplicator {
    OutputDeduplicator() {
    }

    static DeduplicationResult ensureSingleProducer(QueryablePipeline pipeline, Collection<ExecutableStage> stages, Collection<PipelineNode.PTransformNode> unfusedTransforms) {
        RunnerApi.Components.Builder unzippedComponents = pipeline.getComponents().toBuilder();
        Multimap<PipelineNode.PCollectionNode, StageOrTransform> pcollectionProducers = OutputDeduplicator.getProducers(pipeline, stages, unfusedTransforms);
        HashMultimap<StageOrTransform, PipelineNode.PCollectionNode> requiresNewOutput = HashMultimap.create();
        for (Map.Entry<PipelineNode.PCollectionNode, Collection<StageOrTransform>> collectionProducer : pcollectionProducers.asMap().entrySet()) {
            if (collectionProducer.getValue().size() <= 1) continue;
            for (StageOrTransform stageOrTransform : collectionProducer.getValue()) {
                requiresNewOutput.put(stageOrTransform, collectionProducer.getKey());
            }
        }
        LinkedHashMap<ExecutableStage, ExecutableStage> updatedStages = new LinkedHashMap<ExecutableStage, ExecutableStage>();
        LinkedHashMap<String, PipelineNode.PTransformNode> updatedTransforms = new LinkedHashMap<String, PipelineNode.PTransformNode>();
        HashMultimap<String, PipelineNode.PCollectionNode> originalToPartial = HashMultimap.create();
        for (Map.Entry deduplicationTargets : requiresNewOutput.asMap().entrySet()) {
            if (((StageOrTransform)deduplicationTargets.getKey()).getStage() != null) {
                StageDeduplication stageDeduplication = OutputDeduplicator.deduplicatePCollections(((StageOrTransform)deduplicationTargets.getKey()).getStage(), (Collection<PipelineNode.PCollectionNode>)((Collection)deduplicationTargets.getValue()), arg_0 -> ((RunnerApi.Components.Builder)unzippedComponents).containsPcollections(arg_0));
                for (Map.Entry<String, PipelineNode.PCollectionNode> originalToPartialReplacement : stageDeduplication.getOriginalToPartialPCollections().entrySet()) {
                    originalToPartial.put(originalToPartialReplacement.getKey(), originalToPartialReplacement.getValue());
                    unzippedComponents.putPcollections(originalToPartialReplacement.getValue().getId(), originalToPartialReplacement.getValue().getPCollection());
                }
                updatedStages.put(((StageOrTransform)deduplicationTargets.getKey()).getStage(), stageDeduplication.getUpdatedStage());
                continue;
            }
            if (((StageOrTransform)deduplicationTargets.getKey()).getTransform() != null) {
                PTransformDeduplication pTransformDeduplication = OutputDeduplicator.deduplicatePCollections(((StageOrTransform)deduplicationTargets.getKey()).getTransform(), (Collection<PipelineNode.PCollectionNode>)((Collection)deduplicationTargets.getValue()), arg_0 -> ((RunnerApi.Components.Builder)unzippedComponents).containsPcollections(arg_0));
                for (Map.Entry<String, PipelineNode.PCollectionNode> originalToPartialReplacement : pTransformDeduplication.getOriginalToPartialPCollections().entrySet()) {
                    originalToPartial.put(originalToPartialReplacement.getKey(), originalToPartialReplacement.getValue());
                    unzippedComponents.putPcollections(originalToPartialReplacement.getValue().getId(), originalToPartialReplacement.getValue().getPCollection());
                }
                updatedTransforms.put(((StageOrTransform)deduplicationTargets.getKey()).getTransform().getId(), pTransformDeduplication.getUpdatedTransform());
                continue;
            }
            throw new IllegalStateException(String.format("%s with no %s or %s", StageOrTransform.class.getSimpleName(), ExecutableStage.class.getSimpleName(), PipelineNode.PTransformNode.class.getSimpleName()));
        }
        LinkedHashSet<PipelineNode.PTransformNode> linkedHashSet = new LinkedHashSet<PipelineNode.PTransformNode>();
        for (Map.Entry entry : originalToPartial.asMap().entrySet()) {
            String flattenId = SyntheticComponents.uniqueId("unzipped_flatten", arg_0 -> ((RunnerApi.Components.Builder)unzippedComponents).containsTransforms(arg_0));
            RunnerApi.PTransform flattenPartialPCollections = OutputDeduplicator.createFlattenOfPartials(flattenId, (String)entry.getKey(), (Collection)entry.getValue());
            unzippedComponents.putTransforms(flattenId, flattenPartialPCollections);
            linkedHashSet.add(PipelineNode.pTransform(flattenId, flattenPartialPCollections));
        }
        RunnerApi.Components components = unzippedComponents.build();
        return DeduplicationResult.of(components, linkedHashSet, updatedStages, updatedTransforms);
    }

    private static RunnerApi.PTransform createFlattenOfPartials(String transformId, String outputId, Collection<PipelineNode.PCollectionNode> generatedInputs) {
        RunnerApi.PTransform.Builder newFlattenBuilder = RunnerApi.PTransform.newBuilder();
        int i = 0;
        for (PipelineNode.PCollectionNode generatedInput : generatedInputs) {
            String localInputId = String.format("input_%s", i);
            ++i;
            newFlattenBuilder.putInputs(localInputId, generatedInput.getId());
        }
        return newFlattenBuilder.setUniqueName(transformId).putOutputs("output", outputId).setSpec(RunnerApi.FunctionSpec.newBuilder().setUrn(PTransformTranslation.FLATTEN_TRANSFORM_URN)).build();
    }

    private static Multimap<PipelineNode.PCollectionNode, StageOrTransform> getProducers(QueryablePipeline pipeline, Iterable<ExecutableStage> stages, Iterable<PipelineNode.PTransformNode> unfusedTransforms) {
        HashMultimap<PipelineNode.PCollectionNode, StageOrTransform> pcollectionProducers = HashMultimap.create();
        for (ExecutableStage stage : stages) {
            for (PipelineNode.PCollectionNode output : stage.getOutputPCollections()) {
                pcollectionProducers.put(output, StageOrTransform.stage(stage));
            }
        }
        for (PipelineNode.PTransformNode unfused : unfusedTransforms) {
            for (PipelineNode.PCollectionNode output : pipeline.getOutputPCollections(unfused)) {
                pcollectionProducers.put(output, StageOrTransform.transform(unfused));
            }
        }
        return pcollectionProducers;
    }

    private static PTransformDeduplication deduplicatePCollections(PipelineNode.PTransformNode transform, Collection<PipelineNode.PCollectionNode> duplicates, Predicate<String> existingPCollectionIds) {
        Map<String, PipelineNode.PCollectionNode> unzippedOutputs = OutputDeduplicator.createPartialPCollections(duplicates, existingPCollectionIds);
        RunnerApi.PTransform pTransform = OutputDeduplicator.updateOutputs(transform.getTransform(), unzippedOutputs);
        return PTransformDeduplication.of(PipelineNode.pTransform(transform.getId(), pTransform), unzippedOutputs);
    }

    private static StageDeduplication deduplicatePCollections(ExecutableStage stage, Collection<PipelineNode.PCollectionNode> duplicates, Predicate<String> existingPCollectionIds) {
        Map<String, PipelineNode.PCollectionNode> unzippedOutputs = OutputDeduplicator.createPartialPCollections(duplicates, existingPCollectionIds);
        ExecutableStage updatedStage = OutputDeduplicator.deduplicateStageOutput(stage, unzippedOutputs);
        return StageDeduplication.of(updatedStage, unzippedOutputs);
    }

    private static Map<String, PipelineNode.PCollectionNode> createPartialPCollections(Collection<PipelineNode.PCollectionNode> duplicates, Predicate<String> existingPCollectionIds) {
        LinkedHashMap<String, PipelineNode.PCollectionNode> unzippedOutputs = new LinkedHashMap<String, PipelineNode.PCollectionNode>();
        Predicate<String> existingOrNewIds = existingPCollectionIds.or(id -> unzippedOutputs.values().stream().map(PipelineNode.PCollectionNode::getId).anyMatch(id::equals));
        for (PipelineNode.PCollectionNode duplicateOutput : duplicates) {
            String id2 = SyntheticComponents.uniqueId(duplicateOutput.getId(), existingOrNewIds);
            RunnerApi.PCollection partial = duplicateOutput.getPCollection().toBuilder().setUniqueName(id2).build();
            PipelineNode.PCollectionNode alreadyDeduplicated = unzippedOutputs.put(duplicateOutput.getId(), PipelineNode.pCollection(id2, partial));
            Preconditions.checkArgument(alreadyDeduplicated == null, "a duplicate should only appear once per stage");
        }
        return unzippedOutputs;
    }

    private static ExecutableStage deduplicateStageOutput(ExecutableStage stage, Map<String, PipelineNode.PCollectionNode> originalToPartial) {
        ArrayList<PipelineNode.PTransformNode> updatedTransforms = new ArrayList<PipelineNode.PTransformNode>();
        for (PipelineNode.PTransformNode pTransformNode : stage.getTransforms()) {
            RunnerApi.PTransform updatedTransform = OutputDeduplicator.updateOutputs(pTransformNode.getTransform(), originalToPartial);
            updatedTransforms.add(PipelineNode.pTransform(pTransformNode.getId(), updatedTransform));
        }
        ArrayList<PipelineNode.PCollectionNode> updatedOutputs = new ArrayList<PipelineNode.PCollectionNode>();
        for (PipelineNode.PCollectionNode output : stage.getOutputPCollections()) {
            updatedOutputs.add(originalToPartial.getOrDefault(output.getId(), output));
        }
        RunnerApi.Components components = stage.getComponents().toBuilder().clearTransforms().putAllTransforms(updatedTransforms.stream().collect(Collectors.toMap(PipelineNode.PTransformNode::getId, PipelineNode.PTransformNode::getTransform))).putAllPcollections(originalToPartial.values().stream().collect(Collectors.toMap(PipelineNode.PCollectionNode::getId, PipelineNode.PCollectionNode::getPCollection))).build();
        return ImmutableExecutableStage.of(components, stage.getEnvironment(), stage.getInputPCollection(), stage.getSideInputs(), stage.getUserStates(), stage.getTimers(), updatedTransforms, updatedOutputs);
    }

    private static RunnerApi.PTransform updateOutputs(RunnerApi.PTransform transform, Map<String, PipelineNode.PCollectionNode> originalToPartial) {
        RunnerApi.PTransform.Builder updatedTransformBuilder = transform.toBuilder();
        for (Map.Entry output : transform.getOutputsMap().entrySet()) {
            if (!originalToPartial.containsKey(output.getValue())) continue;
            updatedTransformBuilder.putOutputs((String)output.getKey(), originalToPartial.get(output.getValue()).getId());
        }
        return updatedTransformBuilder.build();
    }

    @AutoValue
    static abstract class StageOrTransform {
        StageOrTransform() {
        }

        public static StageOrTransform stage(ExecutableStage stage) {
            return new AutoValue_OutputDeduplicator_StageOrTransform(stage, null);
        }

        public static StageOrTransform transform(PipelineNode.PTransformNode transform) {
            return new AutoValue_OutputDeduplicator_StageOrTransform(null, transform);
        }

        @Nullable
        abstract ExecutableStage getStage();

        @Nullable
        abstract PipelineNode.PTransformNode getTransform();
    }

    @AutoValue
    static abstract class StageDeduplication {
        StageDeduplication() {
        }

        public static StageDeduplication of(ExecutableStage updatedStage, Map<String, PipelineNode.PCollectionNode> originalToPartial) {
            return new AutoValue_OutputDeduplicator_StageDeduplication(updatedStage, originalToPartial);
        }

        abstract ExecutableStage getUpdatedStage();

        abstract Map<String, PipelineNode.PCollectionNode> getOriginalToPartialPCollections();
    }

    @AutoValue
    static abstract class PTransformDeduplication {
        PTransformDeduplication() {
        }

        public static PTransformDeduplication of(PipelineNode.PTransformNode updatedTransform, Map<String, PipelineNode.PCollectionNode> originalToPartial) {
            return new AutoValue_OutputDeduplicator_PTransformDeduplication(updatedTransform, originalToPartial);
        }

        abstract PipelineNode.PTransformNode getUpdatedTransform();

        abstract Map<String, PipelineNode.PCollectionNode> getOriginalToPartialPCollections();
    }

    @AutoValue
    static abstract class DeduplicationResult {
        DeduplicationResult() {
        }

        private static DeduplicationResult of(RunnerApi.Components components, Set<PipelineNode.PTransformNode> introducedTransforms, Map<ExecutableStage, ExecutableStage> stages, Map<String, PipelineNode.PTransformNode> unfused) {
            return new AutoValue_OutputDeduplicator_DeduplicationResult(components, introducedTransforms, stages, unfused);
        }

        abstract RunnerApi.Components getDeduplicatedComponents();

        abstract Set<PipelineNode.PTransformNode> getIntroducedTransforms();

        abstract Map<ExecutableStage, ExecutableStage> getDeduplicatedStages();

        abstract Map<String, PipelineNode.PTransformNode> getDeduplicatedTransforms();
    }
}

