/*
 * Decompiled with CFR 0.152.
 */
package org.apache.drill.exec.planner.physical.visitor;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.atomic.AtomicLong;
import org.apache.calcite.plan.volcano.RelSubset;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.JoinInfo;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.drill.exec.ops.QueryContext;
import org.apache.drill.exec.planner.physical.BroadcastExchangePrel;
import org.apache.drill.exec.planner.physical.ExchangePrel;
import org.apache.drill.exec.planner.physical.HashAggPrel;
import org.apache.drill.exec.planner.physical.HashJoinPrel;
import org.apache.drill.exec.planner.physical.JoinPrel;
import org.apache.drill.exec.planner.physical.Prel;
import org.apache.drill.exec.planner.physical.RuntimeFilterPrel;
import org.apache.drill.exec.planner.physical.ScanPrel;
import org.apache.drill.exec.planner.physical.SortPrel;
import org.apache.drill.exec.planner.physical.StreamAggPrel;
import org.apache.drill.exec.planner.physical.TopNPrel;
import org.apache.drill.exec.planner.physical.visitor.BasePrelVisitor;
import org.apache.drill.exec.work.filter.BloomFilter;
import org.apache.drill.exec.work.filter.BloomFilterDef;
import org.apache.drill.exec.work.filter.RuntimeFilterDef;
import org.apache.drill.shaded.guava.com.google.common.collect.HashMultimap;
import org.apache.drill.shaded.guava.com.google.common.collect.Multimap;

/**
 * Physical-plan visitor that attaches bloom-filter based runtime filters to hash joins.
 *
 * <p>A {@link HashJoinPrel} gets a {@link RuntimeFilterDef} when all of the following hold:
 * <ul>
 *   <li>it is an equi-join of type INNER or RIGHT;</li>
 *   <li>its right (build) side reaches an {@link ExchangePrel} through a chain of single-input nodes;</li>
 *   <li>at least one left join key can be traced, by field name, to a {@link ScanPrel} with no
 *       blocking operator (stream/hash aggregate, sort, top-N, hash join) in between.</li>
 * </ul>
 * Each such probe-side scan is then wrapped in one {@link RuntimeFilterPrel} per targeting join.
 *
 * <p>Instances are created per call by {@link #addRuntimeFilter} and hold mutable,
 * unsynchronized state.
 */
public class RuntimeFilterVisitor extends BasePrelVisitor<Prel, Void, RuntimeException> {
    /** Probe-side scans that must be wrapped in a RuntimeFilterPrel. */
    private final Set<ScanPrel> toAddRuntimeFilter = new HashSet<ScanPrel>();
    /** For each probe-side scan, the hash joins whose runtime filter targets it. */
    private final Multimap<ScanPrel, HashJoinPrel> probeSideScan2hj = HashMultimap.create();
    /** Target false-positive probability (option exec.hashjoin.bloom_filter.fpp). */
    private final double fpp;
    /** Upper bound for one bloom filter's size in bytes (option exec.hashjoin.bloom_filter.max.size). */
    private final int bloomFilterMaxSizeInBytesDef;
    /** Identifier source shared by all instances, so identifiers are unique within this JVM. */
    private static final AtomicLong rfIdCounter = new AtomicLong();

    private RuntimeFilterVisitor(QueryContext queryContext) {
        this.bloomFilterMaxSizeInBytesDef = queryContext.getOption("exec.hashjoin.bloom_filter.max.size").num_val.intValue();
        this.fpp = queryContext.getOption("exec.hashjoin.bloom_filter.fpp").float_val;
    }

    /**
     * Adds runtime filters to the given plan.
     *
     * <p>The first pass generates filter definitions and wraps probe-side scans. The second pass
     * marks every generated filter for generation and decides whether it is routed through the
     * foreman or applied locally.
     *
     * @param prel the root of the physical plan
     * @param queryContext supplies the bloom filter size and false-positive options
     * @return the (possibly rewritten) plan root
     */
    public static Prel addRuntimeFilter(Prel prel, QueryContext queryContext) {
        RuntimeFilterVisitor instance = new RuntimeFilterVisitor(queryContext);
        Prel finalPrel = prel.accept(instance, null);
        RuntimeFilterInfoPaddingHelper runtimeFilterInfoPaddingHelper = new RuntimeFilterInfoPaddingHelper();
        runtimeFilterInfoPaddingHelper.visitPrel(finalPrel, null);
        return finalPrel;
    }

    /**
     * Visits all children and returns a copy of {@code prel} only if some child was replaced.
     */
    @Override
    public Prel visitPrel(Prel prel, Void value) throws RuntimeException {
        // RelNode.copy expects List<RelNode>; a List<Prel> would not type-check.
        List<RelNode> children = new ArrayList<RelNode>();
        for (Prel child : prel) {
            children.add(child.accept(this, value));
        }
        if (children.equals(prel.getInputs())) {
            return prel;
        }
        return (Prel) prel.copy(prel.getTraitSet(), children);
    }

    /**
     * Computes the runtime filter for a hash join before visiting its children. This ordering
     * lets the probe-side scans visited below find the join in {@link #probeSideScan2hj}.
     */
    @Override
    public Prel visitJoin(JoinPrel prel, Void value) throws RuntimeException {
        if (prel instanceof HashJoinPrel) {
            HashJoinPrel hashJoinPrel = (HashJoinPrel) prel;
            hashJoinPrel.setRuntimeFilterDef(this.generateRuntimeFilter(hashJoinPrel));
        }
        return this.visitPrel(prel, value);
    }

    /**
     * Wraps a registered probe-side scan in one RuntimeFilterPrel per targeting hash join. Each
     * join's filter definition receives the identifier of its wrapper. Unregistered scans, or
     * registered scans with no mapped join, are returned unchanged.
     */
    @Override
    public Prel visitScan(ScanPrel prel, Void value) throws RuntimeException {
        if (!this.toAddRuntimeFilter.contains(prel)) {
            return prel;
        }
        Prel wrapped = prel;
        for (HashJoinPrel hashJoinPrel : this.probeSideScan2hj.get(prel)) {
            long identifier = rfIdCounter.incrementAndGet();
            hashJoinPrel.getRuntimeFilterDef().setRuntimeFilterIdentifier(identifier);
            wrapped = new RuntimeFilterPrel(wrapped, identifier);
        }
        return wrapped;
    }

    /**
     * Builds the runtime filter definition for one hash join and registers its probe-side scan.
     *
     * @return the definition, or {@code null} when the join does not qualify or no join key
     *     could be traced to a scan without a blocking operator in between
     */
    private RuntimeFilterDef generateRuntimeFilter(HashJoinPrel hashJoinPrel) {
        JoinRelType joinRelType = hashJoinPrel.getJoinType();
        JoinInfo joinInfo = hashJoinPrel.analyzeCondition();
        boolean allowJoin = joinInfo.isEqui()
            && (joinRelType == JoinRelType.INNER || joinRelType == JoinRelType.RIGHT);
        if (!allowJoin) {
            return null;
        }
        RelNode left = hashJoinPrel.getLeft();
        RelNode right = hashJoinPrel.getRight();
        // Only joins whose build side is fed through an exchange get a runtime filter.
        if (this.findRightExchangePrel(right) == null) {
            return null;
        }
        List<String> leftFields = left.getRowType().getFieldNames();
        List<String> rightFields = right.getRowType().getFieldNames();
        List<Integer> leftKeys = hashJoinPrel.getLeftKeys();
        List<Integer> rightKeys = hashJoinPrel.getRightKeys();
        RelMetadataQuery metadataQuery = left.getCluster().getMetadataQuery();
        List<BloomFilterDef> bloomFilterDefs = new ArrayList<BloomFilterDef>();
        ScanPrel probeSideScanPrel = null;
        for (int i = 0; i < leftKeys.size(); i++) {
            String leftFieldName = leftFields.get(leftKeys.get(i));
            String rightFieldName = rightFields.get(rightKeys.get(i));
            // NOTE(review): keys are matched to scan columns by name only. A Project on the probe
            // path that renames/computes a column sharing a scan column's name could resolve to
            // the wrong scan column -- confirm this cannot happen in practice.
            ScanPrel scanPrel = this.findLeftScanPrel(leftFieldName, left);
            if (scanPrel == null || this.containBlockNode((Prel) left, scanPrel)) {
                continue;
            }
            RelDataTypeField field = scanPrel.getRowType().getField(leftFieldName, true, true);
            if (field == null) {
                continue;
            }
            Double ndv = metadataQuery.getDistinctRowCount(scanPrel, ImmutableBitSet.of(field.getIndex()), null);
            if (ndv == null) {
                // No NDV statistics: fall back to 10% of the probe side's estimated row count.
                ndv = left.estimateRowCount(metadataQuery) * 0.1;
            }
            int bloomFilterSizeInBytes = Math.min(
                BloomFilter.optimalNumOfBytes(ndv.longValue(), this.fpp),
                this.bloomFilterMaxSizeInBytesDef);
            BloomFilterDef bloomFilterDef = new BloomFilterDef(bloomFilterSizeInBytes, false, leftFieldName, rightFieldName);
            bloomFilterDef.setLeftNDV(ndv);
            bloomFilterDefs.add(bloomFilterDef);
            this.toAddRuntimeFilter.add(scanPrel);
            probeSideScanPrel = scanPrel;
        }
        if (bloomFilterDefs.isEmpty()) {
            return null;
        }
        RuntimeFilterDef runtimeFilterDef = new RuntimeFilterDef(true, false, bloomFilterDefs, false, -1L);
        this.probeSideScan2hj.put(probeSideScanPrel, hashJoinPrel);
        return runtimeFilterDef;
    }

    /**
     * Follows first inputs (and RelSubset best nodes) down from {@code leftRelNode}.
     *
     * @return the first ScanPrel reached, if its row type contains {@code fieldName};
     *     {@code null} otherwise, including when a non-scan leaf is reached
     */
    private ScanPrel findLeftScanPrel(String fieldName, RelNode leftRelNode) {
        if (leftRelNode instanceof ScanPrel) {
            RelDataTypeField field = leftRelNode.getRowType().getField(fieldName, true, true);
            return field != null ? (ScanPrel) leftRelNode : null;
        }
        if (leftRelNode instanceof RelSubset) {
            RelNode bestNode = ((RelSubset) leftRelNode).getBest();
            return bestNode != null ? this.findLeftScanPrel(fieldName, bestNode) : null;
        }
        List<RelNode> relNodes = leftRelNode.getInputs();
        if (relNodes.isEmpty()) {
            // A leaf that is not a scan (e.g. an inline values node): nothing to filter.
            return null;
        }
        return this.findLeftScanPrel(fieldName, relNodes.get(0));
    }

    /**
     * Follows single-input nodes (and RelSubset best nodes) down from {@code rightRelNode}.
     *
     * @return the first ExchangePrel reached, or {@code null} when a scan, a node with zero or
     *     several inputs, or an empty RelSubset is reached first
     */
    private ExchangePrel findRightExchangePrel(RelNode rightRelNode) {
        if (rightRelNode instanceof ExchangePrel) {
            return (ExchangePrel) rightRelNode;
        }
        if (rightRelNode instanceof ScanPrel) {
            return null;
        }
        if (rightRelNode instanceof RelSubset) {
            RelNode bestNode = ((RelSubset) rightRelNode).getBest();
            return bestNode != null ? this.findRightExchangePrel(bestNode) : null;
        }
        List<RelNode> relNodes = rightRelNode.getInputs();
        if (relNodes.size() == 1) {
            return this.findRightExchangePrel(relNodes.get(0));
        }
        return null;
    }

    /**
     * Reports whether a blocking operator appears in the subtree under {@code startNode}. The walk
     * does not descend past {@code endNode}.
     */
    private boolean containBlockNode(Prel startNode, Prel endNode) {
        BlockNodeVisitor blockNodeVisitor = new BlockNodeVisitor();
        startNode.accept(blockNodeVisitor, endNode);
        return blockNodeVisitor.isEncounteredBlockNode();
    }

    /**
     * Second pass over the finished plan. For every hash join carrying a RuntimeFilterDef, it
     * enables bloom filter generation and decides the routing from the join's build-side
     * exchange. If the topmost build-side exchange is not a broadcast, the filter is sent to the
     * foreman; otherwise it is applied locally.
     */
    private static class RuntimeFilterInfoPaddingHelper
    extends BasePrelVisitor<Void, RFHelperHolder, RuntimeException> {
        @Override
        public Void visitPrel(Prel prel, RFHelperHolder holder) throws RuntimeException {
            for (Prel child : prel) {
                child.accept(this, holder);
            }
            return null;
        }

        @Override
        public Void visitExchange(ExchangePrel exchange, RFHelperHolder holder) throws RuntimeException {
            if (holder != null && holder.isFromBuildSide()) {
                holder.setBuildSideExchange(exchange);
            }
            return this.visitPrel(exchange, holder);
        }

        @Override
        public Void visitJoin(JoinPrel prel, RFHelperHolder holder) throws RuntimeException {
            RuntimeFilterDef runtimeFilterDef = prel instanceof HashJoinPrel
                ? ((HashJoinPrel) prel).getRuntimeFilterDef()
                : null;
            if (runtimeFilterDef == null) {
                return this.visitPrel(prel, holder);
            }
            runtimeFilterDef.setGenerateBloomFilter(true);
            // Probe side: continue the caller's traversal state unchanged.
            ((Prel) prel.getLeft()).accept(this, holder);
            // Build side: use a dedicated holder. This join's routing then depends only on its own
            // build-side exchange, not on exchanges recorded by nested or sibling joins.
            RFHelperHolder buildSideHolder = new RFHelperHolder();
            buildSideHolder.setFromBuildSide(true);
            ((Prel) prel.getRight()).accept(this, buildSideHolder);
            boolean routeToForeman = buildSideHolder.needToRouteToForeman();
            runtimeFilterDef.setSendToForeman(routeToForeman);
            for (BloomFilterDef bloomFilterDef : runtimeFilterDef.getBloomFilterDefs()) {
                bloomFilterDef.setLocal(!routeToForeman);
            }
            // Both children were visited exactly once above; do not traverse them again.
            return null;
        }
    }

    /** Traversal state for one hash join's build side. */
    private static class RFHelperHolder {
        private boolean fromBuildSide;
        private ExchangePrel exchangePrel;

        private RFHelperHolder() {
        }

        /**
         * Records the first exchange seen. Exchanges are recorded before their children are
         * visited, so this is the topmost build-side exchange; deeper ones are ignored.
         */
        public void setBuildSideExchange(ExchangePrel exchange) {
            if (this.exchangePrel == null) {
                this.exchangePrel = exchange;
            }
        }

        /** True when an exchange was recorded and it is not a broadcast exchange. */
        public boolean needToRouteToForeman() {
            return this.exchangePrel != null && !(this.exchangePrel instanceof BroadcastExchangePrel);
        }

        public boolean isFromBuildSide() {
            return this.fromBuildSide;
        }

        public void setFromBuildSide(boolean fromBuildSide) {
            this.fromBuildSide = fromBuildSide;
        }
    }

    /**
     * Sets a flag when it meets a StreamAggPrel, HashAggPrel, SortPrel, TopNPrel or HashJoinPrel.
     * It does not descend below the end node passed as the visitor argument.
     */
    private static class BlockNodeVisitor
    extends BasePrelVisitor<Void, Prel, RuntimeException> {
        private boolean encounteredBlockNode;

        private BlockNodeVisitor() {
        }

        @Override
        public Void visitPrel(Prel prel, Prel endValue) throws RuntimeException {
            // Once a blocking node is found the answer cannot change; stop walking.
            if (this.encounteredBlockNode || prel == endValue) {
                return null;
            }
            Prel currentPrel = prel instanceof RelSubset ? (Prel) ((RelSubset) prel).getBest() : prel;
            if (currentPrel == null) {
                return null;
            }
            if (currentPrel instanceof StreamAggPrel
                || currentPrel instanceof HashAggPrel
                || currentPrel instanceof SortPrel
                || currentPrel instanceof TopNPrel
                || currentPrel instanceof HashJoinPrel) {
                this.encounteredBlockNode = true;
                return null;
            }
            for (Prel subPrel : currentPrel) {
                this.visitPrel(subPrel, endValue);
            }
            return null;
        }

        public boolean isEncounteredBlockNode() {
            return this.encounteredBlockNode;
        }
    }
}

