/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.optimizer.traversals;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import org.apache.flink.api.common.InvalidProgramException;
import org.apache.flink.api.common.distributions.CommonRangeBoundaries;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.operators.Order;
import org.apache.flink.api.common.operators.Ordering;
import org.apache.flink.api.common.operators.SingleInputOperator;
import org.apache.flink.api.common.operators.UnaryOperatorInformation;
import org.apache.flink.api.common.operators.base.GroupReduceOperatorBase;
import org.apache.flink.api.common.operators.base.MapOperatorBase;
import org.apache.flink.api.common.operators.base.MapPartitionOperatorBase;
import org.apache.flink.api.common.operators.util.FieldList;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeutils.TypeComparatorFactory;
import org.apache.flink.api.java.functions.IdPartitioner;
import org.apache.flink.api.java.functions.SampleInCoordinator;
import org.apache.flink.api.java.functions.SampleInPartition;
import org.apache.flink.api.java.sampling.IntermediateSampleData;
import org.apache.flink.api.java.typeutils.TupleTypeInfo;
import org.apache.flink.api.java.typeutils.TypeExtractor;
import org.apache.flink.optimizer.costs.Costs;
import org.apache.flink.optimizer.dag.GroupReduceNode;
import org.apache.flink.optimizer.dag.MapNode;
import org.apache.flink.optimizer.dag.MapPartitionNode;
import org.apache.flink.optimizer.dag.TempMode;
import org.apache.flink.optimizer.dataproperties.GlobalProperties;
import org.apache.flink.optimizer.dataproperties.LocalProperties;
import org.apache.flink.optimizer.plan.Channel;
import org.apache.flink.optimizer.plan.IterationPlanNode;
import org.apache.flink.optimizer.plan.NamedChannel;
import org.apache.flink.optimizer.plan.OptimizedPlan;
import org.apache.flink.optimizer.plan.PlanNode;
import org.apache.flink.optimizer.plan.SingleInputPlanNode;
import org.apache.flink.optimizer.util.Utils;
import org.apache.flink.runtime.io.network.DataExchangeMode;
import org.apache.flink.runtime.operators.DriverStrategy;
import org.apache.flink.runtime.operators.shipping.ShipStrategyType;
import org.apache.flink.runtime.operators.udf.AssignRangeIndex;
import org.apache.flink.runtime.operators.udf.RangeBoundaryBuilder;
import org.apache.flink.runtime.operators.udf.RemoveRangeIndex;
import org.apache.flink.util.Visitor;

/**
 * Plan visitor that rewrites range-partitioned channels which carry no data distribution.
 *
 * <p>For every such input channel, the single channel is replaced by a small sub-plan:
 * <ol>
 *   <li>{@link #SIP_NAME}: a sampling map-partition on the source, at source parallelism;</li>
 *   <li>{@link #SIC_NAME}: an all-group-reduce with parallelism 1 that merges the samples;</li>
 *   <li>{@link #RB_NAME}: a map-partition with parallelism 1 that builds the range boundaries;</li>
 *   <li>{@link #ARI_NAME}: a map-partition on the source that receives the boundaries as the
 *       broadcast input {@code "RangeBoundaries"} and emits {@code (int, record)} tuples;</li>
 *   <li>{@link #PR_NAME}: a map at target parallelism, fed through a custom partitioning on
 *       field 0, that strips the index again.</li>
 * </ol>
 * The original channel is then re-attached behind the last node as a plain FORWARD channel.
 */
public class RangePartitionRewriter implements Visitor<PlanNode> {
    /** Seed handed to both sampling functions. */
    static final long SEED = 0L;
    static final String SIP_NAME = "RangePartition: LocalSample";
    static final String SIC_NAME = "RangePartition: GlobalSample";
    static final String RB_NAME = "RangePartition: Histogram";
    static final String ARI_NAME = "RangePartition: PreparePartition";
    static final String PR_NAME = "RangePartition: Partition";
    /** Sample size is this value multiplied by the parallelism of the channel's target. */
    static final int SAMPLES_PER_PARTITION = 1000;
    /** Partitioner used on the custom-partitioned channel keyed on tuple field 0. */
    static final IdPartitioner idPartitioner = new IdPartitioner();

    /** The plan being rewritten; newly created plan nodes are registered in it. */
    final OptimizedPlan plan;
    /** Iteration nodes whose step function has already been traversed by this visitor. */
    final Set<IterationPlanNode> visitedIterationNodes;

    /**
     * Creates a rewriter for the given plan.
     *
     * @param plan the optimized plan whose range-partition channels are rewritten in place
     */
    public RangePartitionRewriter(OptimizedPlan plan) {
        this.plan = plan;
        this.visitedIterationNodes = new HashSet<IterationPlanNode>();
    }

    /**
     * Always descends into the visited node.
     *
     * @param visitable the node about to be visited (unused)
     * @return always {@code true}
     */
    public boolean preVisit(PlanNode visitable) {
        return true;
    }

    /**
     * Rewrites every input channel of {@code node} that is range partitioned and has no data
     * distribution attached. Iteration nodes additionally have their step function traversed,
     * once per iteration node.
     *
     * @param node the plan node whose inputs are inspected
     * @throws InvalidProgramException if such a channel feeds a node on a dynamic path
     */
    public void postVisit(PlanNode node) {
        if (node instanceof IterationPlanNode) {
            IterationPlanNode iNode = (IterationPlanNode) node;
            // Set.add reports whether the element was newly inserted, so each step function
            // is traversed at most once.
            if (this.visitedIterationNodes.add(iNode)) {
                iNode.acceptForStepFunction(this);
            }
        }

        for (Channel channel : node.getInputs()) {
            if (channel.getShipStrategy() != ShipStrategyType.PARTITION_RANGE) {
                continue;
            }
            // Channels that already carry a data distribution are left untouched.
            if (channel.getDataDistribution() != null) {
                continue;
            }
            if (node.isOnDynamicPath()) {
                throw new InvalidProgramException("Range Partitioning not supported within iterations if users do not supply the data distribution.");
            }

            // Capture the source first: the rewrite re-points the channel to a new source.
            PlanNode channelSource = channel.getSource();
            List<Channel> newSourceOutputChannels = this.rewriteRangePartitionChannel(channel);
            channelSource.getOutgoingChannels().remove(channel);
            channelSource.getOutgoingChannels().addAll(newSourceOutputChannels);
        }
    }

    /**
     * Builds the sampling / boundary / partitioning sub-plan for one channel and re-attaches
     * the channel behind it.
     *
     * @param channel the range-partitioned channel to rewrite; its source and ship strategy
     *                are modified
     * @return the channels that must become outgoing channels of the channel's original source
     *         (one to the local sampler, one to the range-index assigner)
     */
    private List<Channel> rewriteRangePartitionChannel(Channel channel) {
        List<Channel> sourceNewOutputChannels = new ArrayList<Channel>();
        PlanNode sourceNode = channel.getSource();
        PlanNode targetNode = channel.getTarget();
        int sourceParallelism = sourceNode.getParallelism();
        int targetParallelism = targetNode.getParallelism();
        Costs defaultZeroCosts = new Costs(0.0, 0.0, 0.0);
        TypeComparatorFactory<?> comparator = Utils.getShipComparator(channel, this.plan.getOriginalPlan().getExecutionConfig());

        // 1. Local sampling on the source output, at source parallelism.
        int sampleSize = SAMPLES_PER_PARTITION * targetParallelism;
        SampleInPartition sampleInPartition = new SampleInPartition(false, sampleSize, SEED);
        TypeInformation sourceOutputType = sourceNode.getOptimizerNode().getOperator().getOperatorInfo().getOutputType();
        TypeInformation isdTypeInformation = TypeExtractor.getForClass(IntermediateSampleData.class);
        UnaryOperatorInformation sipOperatorInformation = new UnaryOperatorInformation(sourceOutputType, isdTypeInformation);
        MapPartitionOperatorBase sipOperatorBase = new MapPartitionOperatorBase((MapPartitionFunction)sampleInPartition, sipOperatorInformation, SIP_NAME);
        MapPartitionNode sipNode = new MapPartitionNode((SingleInputOperator<?, ?, ?>)sipOperatorBase);
        Channel sipChannel = new Channel(sourceNode, TempMode.NONE);
        sipChannel.setShipStrategy(ShipStrategyType.FORWARD, DataExchangeMode.PIPELINED);
        SingleInputPlanNode sipPlanNode = new SingleInputPlanNode(sipNode, SIP_NAME, sipChannel, DriverStrategy.MAP_PARTITION);
        sipNode.setParallelism(sourceParallelism);
        sipPlanNode.setParallelism(sourceParallelism);
        sipPlanNode.initProperties(new GlobalProperties(), new LocalProperties());
        sipPlanNode.setCosts(defaultZeroCosts);
        sipChannel.setTarget(sipPlanNode);
        this.plan.getAllNodes().add(sipPlanNode);
        sourceNewOutputChannels.add(sipChannel);

        // 2. Global sampling in a single coordinator task (parallelism 1).
        SampleInCoordinator sampleInCoordinator = new SampleInCoordinator(false, sampleSize, SEED);
        UnaryOperatorInformation sicOperatorInformation = new UnaryOperatorInformation(isdTypeInformation, sourceOutputType);
        GroupReduceOperatorBase sicOperatorBase = new GroupReduceOperatorBase((GroupReduceFunction)sampleInCoordinator, sicOperatorInformation, SIC_NAME);
        GroupReduceNode sicNode = new GroupReduceNode(sicOperatorBase);
        Channel sicChannel = new Channel(sipPlanNode, TempMode.NONE);
        // The samplers run at source parallelism while the coordinator runs with parallelism 1,
        // so this link redistributes data and is not declared as a FORWARD connection.
        sicChannel.setShipStrategy(ShipStrategyType.PARTITION_RANDOM, DataExchangeMode.PIPELINED);
        SingleInputPlanNode sicPlanNode = new SingleInputPlanNode(sicNode, SIC_NAME, sicChannel, DriverStrategy.ALL_GROUP_REDUCE);
        sicNode.setParallelism(1);
        sicPlanNode.setParallelism(1);
        sicPlanNode.initProperties(new GlobalProperties(), new LocalProperties());
        sicPlanNode.setCosts(defaultZeroCosts);
        sicChannel.setTarget(sicPlanNode);
        sipPlanNode.addOutgoingChannel(sicChannel);
        this.plan.getAllNodes().add(sicPlanNode);

        // 3. Range boundary builder, also with parallelism 1 (hence FORWARD from the coordinator).
        RangeBoundaryBuilder rangeBoundaryBuilder = new RangeBoundaryBuilder(comparator, targetParallelism);
        TypeInformation rbTypeInformation = TypeExtractor.getForClass(CommonRangeBoundaries.class);
        UnaryOperatorInformation rbOperatorInformation = new UnaryOperatorInformation(sourceOutputType, rbTypeInformation);
        MapPartitionOperatorBase rbOperatorBase = new MapPartitionOperatorBase((MapPartitionFunction)rangeBoundaryBuilder, rbOperatorInformation, RB_NAME);
        MapPartitionNode rbNode = new MapPartitionNode((SingleInputOperator<?, ?, ?>)rbOperatorBase);
        Channel rbChannel = new Channel(sicPlanNode, TempMode.NONE);
        rbChannel.setShipStrategy(ShipStrategyType.FORWARD, DataExchangeMode.PIPELINED);
        SingleInputPlanNode rbPlanNode = new SingleInputPlanNode(rbNode, RB_NAME, rbChannel, DriverStrategy.MAP_PARTITION);
        rbNode.setParallelism(1);
        rbPlanNode.setParallelism(1);
        rbPlanNode.initProperties(new GlobalProperties(), new LocalProperties());
        rbPlanNode.setCosts(defaultZeroCosts);
        rbChannel.setTarget(rbPlanNode);
        sicPlanNode.addOutgoingChannel(rbChannel);
        this.plan.getAllNodes().add(rbPlanNode);

        // 4. Range-index assignment on the source output; emits (int, record) tuples.
        AssignRangeIndex assignRangeIndex = new AssignRangeIndex(comparator);
        TupleTypeInfo ariOutputTypeInformation = new TupleTypeInfo(new TypeInformation[]{BasicTypeInfo.INT_TYPE_INFO, sourceOutputType});
        UnaryOperatorInformation ariOperatorInformation = new UnaryOperatorInformation(sourceOutputType, (TypeInformation)ariOutputTypeInformation);
        MapPartitionOperatorBase ariOperatorBase = new MapPartitionOperatorBase((MapPartitionFunction)assignRangeIndex, ariOperatorInformation, ARI_NAME);
        MapPartitionNode ariNode = new MapPartitionNode((SingleInputOperator<?, ?, ?>)ariOperatorBase);
        Channel ariChannel = new Channel(sourceNode, TempMode.NONE);
        // NOTE(review): this is the only BATCH exchange in the sub-plan; presumably the source
        // output must be fully buffered while the boundaries are still being computed from the
        // same source - confirm against the runtime's data-exchange semantics.
        ariChannel.setShipStrategy(ShipStrategyType.FORWARD, DataExchangeMode.BATCH);
        SingleInputPlanNode ariPlanNode = new SingleInputPlanNode(ariNode, ARI_NAME, ariChannel, DriverStrategy.MAP_PARTITION);
        ariNode.setParallelism(sourceParallelism);
        ariPlanNode.setParallelism(sourceParallelism);
        ariPlanNode.initProperties(new GlobalProperties(), new LocalProperties());
        ariPlanNode.setCosts(defaultZeroCosts);
        ariChannel.setTarget(ariPlanNode);
        this.plan.getAllNodes().add(ariPlanNode);
        sourceNewOutputChannels.add(ariChannel);

        // The computed boundaries reach the index assigner as a named broadcast input.
        NamedChannel broadcastChannel = new NamedChannel("RangeBoundaries", rbPlanNode);
        broadcastChannel.setShipStrategy(ShipStrategyType.BROADCAST, DataExchangeMode.PIPELINED);
        broadcastChannel.setTarget(ariPlanNode);
        List<NamedChannel> broadcastChannels = new ArrayList<NamedChannel>(1);
        broadcastChannels.add(broadcastChannel);
        ariPlanNode.setBroadcastInputs(broadcastChannels);

        // 5. Custom partitioning on tuple field 0 (the assigned index), then strip the index.
        Channel partChannel = new Channel(ariPlanNode, TempMode.NONE);
        FieldList keys = new FieldList(0);
        partChannel.setShipStrategy(ShipStrategyType.PARTITION_CUSTOM, keys, (Partitioner<?>)idPartitioner, DataExchangeMode.PIPELINED);
        ariPlanNode.addOutgoingChannel(partChannel);
        RemoveRangeIndex partitionIDRemoveWrapper = new RemoveRangeIndex();
        UnaryOperatorInformation prOperatorInformation = new UnaryOperatorInformation((TypeInformation)ariOutputTypeInformation, sourceOutputType);
        MapOperatorBase prOperatorBase = new MapOperatorBase((MapFunction)partitionIDRemoveWrapper, prOperatorInformation, PR_NAME);
        MapNode prRemoverNode = new MapNode((SingleInputOperator<?, ?, ?>)prOperatorBase);
        SingleInputPlanNode prPlanNode = new SingleInputPlanNode(prRemoverNode, PR_NAME, partChannel, DriverStrategy.MAP);
        partChannel.setTarget(prPlanNode);
        prRemoverNode.setParallelism(targetParallelism);
        prPlanNode.setParallelism(targetParallelism);
        GlobalProperties globalProperties = new GlobalProperties();
        globalProperties.setRangePartitioned(new Ordering(0, null, Order.ASCENDING));
        prPlanNode.initProperties(globalProperties, new LocalProperties());
        prPlanNode.setCosts(defaultZeroCosts);
        this.plan.getAllNodes().add(prPlanNode);

        // 6. Re-attach the original channel behind the index remover as a plain forward link.
        channel.setSource(prPlanNode);
        channel.setShipStrategy(ShipStrategyType.FORWARD, DataExchangeMode.PIPELINED);
        prPlanNode.addOutgoingChannel(channel);
        return sourceNewOutputChannels;
    }
}

