/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.planner.plan.rules.logical;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Deque;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.SortedMap;
import java.util.TreeMap;
import java.util.TreeSet;
import javax.annotation.Nonnull;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.plan.hep.HepRelVertex;
import org.apache.calcite.rel.RelCollation;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.RelShuttleImpl;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.CorrelationId;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.core.SetOp;
import org.apache.calcite.rel.core.Sort;
import org.apache.calcite.rel.core.Values;
import org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.calcite.rel.logical.LogicalCorrelate;
import org.apache.calcite.rel.logical.LogicalFilter;
import org.apache.calcite.rel.logical.LogicalIntersect;
import org.apache.calcite.rel.logical.LogicalJoin;
import org.apache.calcite.rel.logical.LogicalMinus;
import org.apache.calcite.rel.logical.LogicalProject;
import org.apache.calcite.rel.logical.LogicalSort;
import org.apache.calcite.rel.logical.LogicalUnion;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexCorrelVariable;
import org.apache.calcite.rex.RexFieldAccess;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexOver;
import org.apache.calcite.rex.RexShuttle;
import org.apache.calcite.rex.RexSubQuery;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.rex.RexVisitor;
import org.apache.calcite.rex.RexVisitorImpl;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.Bug;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.Litmus;
import org.apache.calcite.util.Pair;
import org.apache.calcite.util.ReflectUtil;
import org.apache.calcite.util.ReflectiveVisitor;
import org.apache.calcite.util.Util;
import org.apache.calcite.util.mapping.Mappings;
import org.apache.flink.calcite.shaded.com.google.common.base.Supplier;
import org.apache.flink.calcite.shaded.com.google.common.collect.ImmutableMap;
import org.apache.flink.calcite.shaded.com.google.common.collect.ImmutableSet;
import org.apache.flink.calcite.shaded.com.google.common.collect.ImmutableSortedMap;
import org.apache.flink.calcite.shaded.com.google.common.collect.Multimap;
import org.apache.flink.calcite.shaded.com.google.common.collect.Multimaps;
import org.apache.flink.calcite.shaded.com.google.common.collect.Sets;
import org.apache.flink.calcite.shaded.com.google.common.collect.SortedSetMultimap;
import org.apache.flink.table.planner.calcite.FlinkRelBuilder;
import org.apache.flink.table.planner.plan.utils.FlinkRelOptUtil;
import org.apache.flink.table.planner.plan.utils.FlinkRexUtil;
import org.apache.flink.util.Preconditions;

public class SubQueryDecorrelator
extends RelShuttleImpl {
    /** Rewrites the relational expression of each sub-query into its decorrelated form. */
    private final SubQueryRelDecorrelator decorrelator;
    /** Builder used to put helper projections on top of sub-query plans. */
    private final RelBuilder relBuilder;
    /** Collected result: sub-query -> (decorrelated plan, rewritten correlation condition). */
    private final Map<RexSubQuery, Pair<RelNode, RexNode>> subQueryMap = new HashMap<>();

    /**
     * Creates the shuttle; use {@link #decorrelateQuery} instead of calling this directly.
     *
     * @param decorrelator decorrelator applied to the plan inside every sub-query
     * @param relBuilder builder used to add helper projections
     */
    private SubQueryDecorrelator(SubQueryRelDecorrelator decorrelator, RelBuilder relBuilder) {
        this.relBuilder = relBuilder;
        this.decorrelator = decorrelator;
    }

    /**
     * Analyzes the correlated sub-queries under {@code rootRel} and decorrelates those that
     * appear in filter conditions.
     *
     * @param rootRel root of the plan to analyze
     * @return {@code null} when the plan has a nested correlation scope or a correlation condition
     *     this class cannot handle; {@link Result#EMPTY} when the plan has no correlation at all;
     *     otherwise a {@link Result} holding the decorrelated equivalent of each handled sub-query
     */
    public static Result decorrelateQuery(RelNode rootRel) {
        final int maxCnfNodeCount = FlinkRelOptUtil.getMaxCnfNodeCount(rootRel);
        final CorelMapBuilder mapBuilder = new CorelMapBuilder(maxCnfNodeCount);
        final CorelMap corelMap = mapBuilder.build(rootRel);

        final boolean cannotDecorrelate =
                mapBuilder.hasNestedCorScope || mapBuilder.hasUnsupportedCorCondition;
        if (cannotDecorrelate) {
            return null;
        }
        if (!corelMap.hasCorrelation()) {
            return Result.EMPTY;
        }

        final RelOptCluster cluster = rootRel.getCluster();
        final FlinkRelBuilder flinkRelBuilder =
                new FlinkRelBuilder(cluster.getPlanner().getContext(), cluster, null);
        final SubQueryRelDecorrelator relDecorrelator =
                new SubQueryRelDecorrelator(
                        corelMap, flinkRelBuilder, cluster.getRexBuilder(), maxCnfNodeCount);
        final SubQueryDecorrelator shuttle =
                new SubQueryDecorrelator(relDecorrelator, flinkRelBuilder);
        rootRel.accept(shuttle);
        return new Result(shuttle.subQueryMap);
    }

    /** Unwraps a {@link HepRelVertex} child before continuing the traversal into it. */
    @Override
    protected RelNode visitChild(RelNode parent, int i, RelNode input) {
        final RelNode unwrapped = stripHep(input);
        return super.visitChild(parent, i, unwrapped);
    }

    /**
     * Decorrelates the sub-queries found in the filter's condition, recording them in
     * {@link #subQueryMap}, and then continues into the filter's input.
     */
    @Override
    public RelNode visit(LogicalFilter filter) {
        final RexVisitorImpl<Void> subQueryHandler = handleSubQuery(filter);
        stack.push(filter);
        try {
            filter.getCondition().accept(subQueryHandler);
        } finally {
            stack.pop();
        }
        return super.visit(filter);
    }

    /**
     * Creates a visitor that decorrelates each {@link RexSubQuery} it encounters in an
     * expression owned by {@code rel}.
     *
     * <p>IN sub-queries first get a top projection ({@link #addProjectionForIn}). EXISTS
     * sub-queries are narrowed to the columns their correlation condition uses
     * ({@link #addProjectionForExists}). A sub-query is recorded in {@link #subQueryMap} only when
     * its plan can be decorrelated and yields a correlation condition. The condition is rewritten
     * with {@code DecorrelateRexShuttle} against {@code rel}'s row type. The operands of a
     * sub-query are not visited.
     */
    private RexVisitorImpl<Void> handleSubQuery(final RelNode rel) {
        return new RexVisitorImpl<Void>(true) {

            @Override
            public Void visitSubQuery(RexSubQuery subQuery) {
                final RelNode subQueryRel =
                        subQuery.getKind() == SqlKind.IN
                                ? addProjectionForIn(subQuery.rel)
                                : subQuery.rel;
                final Frame frame = decorrelator.getInvoke(subQueryRel);
                if (frame == null || frame.c == null) {
                    // Not rewritable, or nothing correlated to pull up: leave it untouched.
                    return null;
                }
                final Frame target =
                        subQuery.getKind() == SqlKind.EXISTS
                                ? addProjectionForExists(frame)
                                : frame;
                final DecorrelateRexShuttle shuttle =
                        new DecorrelateRexShuttle(
                                rel.getRowType(), target.r.getRowType(), rel.getVariablesSet());
                final RexNode rewrittenCondition = target.c.accept(shuttle);
                subQueryMap.put(subQuery, Pair.of(target.r, rewrittenCondition));
                return null;
            }
        };
    }

    /**
     * Returns {@code relNode} unchanged if it is already a {@link LogicalProject}. Otherwise it
     * returns {@code relNode} topped by a forced identity projection that keeps the original
     * field names.
     *
     * <p>Presumably this lets {@code decorrelateRel(LogicalProject)} append the columns the
     * correlation condition needs after the sub-query's own output columns. TODO confirm.
     */
    private RelNode addProjectionForIn(RelNode relNode) {
        if (relNode instanceof LogicalProject) {
            return relNode;
        }
        final RelDataType rowType = relNode.getRowType();
        final int fieldCount = rowType.getFieldCount();
        final List<RexNode> identity = new ArrayList<>(fieldCount);
        int field = 0;
        while (field < fieldCount) {
            identity.add(RexInputRef.of(field, rowType));
            field++;
        }
        relBuilder.clear();
        return relBuilder.push(relNode).project(identity, rowType.getFieldNames(), true).build();
    }

    /**
     * Narrows an EXISTS sub-query's decorrelated plan to the columns referenced by its
     * correlation condition, in ascending index order. The condition is rewritten to the new
     * positions.
     *
     * @param frame decorrelated sub-query; its condition {@code c} is expected to be non-null
     * @return {@code frame} itself when the condition already references every column;
     *     otherwise a new frame over the narrowing projection. The new frame's
     *     {@code oldToNewOutputs} is empty because no original node corresponds to that projection.
     */
    private Frame addProjectionForExists(Frame frame) {
        final RelNode rel = frame.r;
        final RelDataType rowType = rel.getRowType();
        final List<Integer> corIndices = new ArrayList<>(frame.getCorInputRefIndices());
        if (corIndices.size() == rowType.getFieldCount()) {
            // Every column is needed; no projection required.
            return frame;
        }

        Collections.sort(corIndices);
        final List<RexNode> projects = new ArrayList<>(corIndices.size());
        final Map<Integer, Integer> mapInputToOutput = new HashMap<>();
        for (int pos = 0; pos < corIndices.size(); pos++) {
            final int index = corIndices.get(pos);
            projects.add(RexInputRef.of(index, rowType));
            mapInputToOutput.put(index, pos);
        }

        relBuilder.clear();
        final RelNode newProject = relBuilder.push(rel).project(projects).build();
        final RexNode newCondition =
                adjustInputRefs(frame.c, mapInputToOutput, newProject.getRowType());
        return new Frame(rel, newProject, newCondition, new HashMap<Integer, Integer>());
    }

    /**
     * Returns the expression currently held by {@code rel} if it is a {@link HepRelVertex},
     * otherwise {@code rel} itself.
     */
    private static RelNode stripHep(RelNode rel) {
        return rel instanceof HepRelVertex ? ((HepRelVertex) rel).getCurrentRel() : rel;
    }

    /**
     * Converts {@code condition} to CNF, then appends each of its conjuncts to exactly one of the
     * three output lists.
     *
     * @param variableSet correlation ids visible to the node that owns {@code condition}
     * @param condition condition to analyze
     * @param rexBuilder builder used for the CNF conversion
     * @param maxCnfNodeCount node limit handed to {@code FlinkRexUtil.toCnf}
     * @param corConditions receives conjuncts that reference at least one correlation variable,
     *     where every referenced variable is in {@code variableSet} and none sits below an OR or
     *     inside a sub-query's operands
     * @param nonCorConditions receives conjuncts that reference no correlation variable
     * @param unsupportedCorConditions receives all remaining conjuncts
     */
    private static void analyzeCorConditions(
            final Set<CorrelationId> variableSet,
            RexNode condition,
            RexBuilder rexBuilder,
            int maxCnfNodeCount,
            List<RexNode> corConditions,
            List<RexNode> nonCorConditions,
            List<RexNode> unsupportedCorConditions) {
        final RexNode cnf = FlinkRexUtil.toCnf(rexBuilder, maxCnfNodeCount, condition);

        // Tri-state verdict: null = no correlation, TRUE = supported correlation,
        // FALSE = unsupported correlation.
        final RexVisitorImpl<Boolean> classifier =
                new RexVisitorImpl<Boolean>(true) {

                    @Override
                    public Boolean visitFieldAccess(RexFieldAccess fieldAccess) {
                        final RexNode target = fieldAccess.getReferenceExpr();
                        if (target instanceof RexCorrelVariable) {
                            return visitCorrelVariable((RexCorrelVariable) target);
                        }
                        return super.visitFieldAccess(fieldAccess);
                    }

                    @Override
                    public Boolean visitCorrelVariable(RexCorrelVariable correlVariable) {
                        return variableSet.contains(correlVariable.id);
                    }

                    @Override
                    public Boolean visitSubQuery(RexSubQuery subQuery) {
                        // Any correlation among the sub-query's operands is nested correlation,
                        // which is not supported.
                        boolean anyCorrelated = false;
                        for (RexNode operand : subQuery.operands) {
                            if (operand.accept(this) != null) {
                                anyCorrelated = true;
                            }
                        }
                        if (anyCorrelated) {
                            return false;
                        }
                        return null;
                    }

                    @Override
                    public Boolean visitCall(RexCall call) {
                        boolean sawSupported = false;
                        boolean sawUnsupported = false;
                        for (RexNode operand : call.operands) {
                            final Boolean verdict = operand.accept(this);
                            if (verdict == null) {
                                continue;
                            }
                            if (verdict.booleanValue()) {
                                sawSupported = true;
                            } else {
                                sawUnsupported = true;
                            }
                        }
                        if (sawUnsupported) {
                            return false;
                        }
                        if (sawSupported) {
                            // Correlation below an OR is not supported.
                            return call.op.getKind() != SqlKind.OR;
                        }
                        return null;
                    }
                };

        for (RexNode conjunct : RelOptUtil.conjunctions(cnf)) {
            final Boolean verdict = conjunct.accept(classifier);
            if (verdict == null) {
                nonCorConditions.add(conjunct);
            } else if (verdict.booleanValue()) {
                corConditions.add(conjunct);
            } else {
                unsupportedCorConditions.add(conjunct);
            }
        }
    }

    /**
     * Rewrites every {@link RexInputRef} in {@code c} from its old index to the index given by
     * {@code mapOldToNewIndex}, typed against {@code rowType}.
     *
     * <p>A reference whose index and type are both unchanged is kept as the same instance. Every
     * referenced index is expected to be a key of {@code mapOldToNewIndex}. This is asserted; with
     * assertions disabled, a missing key fails with a {@link NullPointerException} on unboxing.
     */
    private static RexNode adjustInputRefs(
            RexNode c, final Map<Integer, Integer> mapOldToNewIndex, final RelDataType rowType) {
        final RexShuttle remapper =
                new RexShuttle() {

                    @Override
                    public RexNode visitInputRef(RexInputRef inputRef) {
                        final int oldIndex = inputRef.getIndex();
                        assert mapOldToNewIndex.containsKey(oldIndex);
                        final int newIndex = mapOldToNewIndex.get(oldIndex);
                        final RexInputRef remapped = RexInputRef.of(newIndex, rowType);
                        final boolean unchanged =
                                remapped.getIndex() == oldIndex
                                        && remapped.getType() == inputRef.getType();
                        return unchanged ? inputRef : remapped;
                    }
                };
        return c.accept(remapper);
    }

    /** Outcome of {@link #decorrelateQuery}: the decorrelated equivalent of each handled sub-query. */
    public static class Result {
        /** Result for plans without any correlation. */
        static final Result EMPTY =
                new Result(Collections.<RexSubQuery, Pair<RelNode, RexNode>>emptyMap());

        private final ImmutableMap<RexSubQuery, Pair<RelNode, RexNode>> subQueryMap;

        private Result(Map<RexSubQuery, Pair<RelNode, RexNode>> subQueryMap) {
            this.subQueryMap = ImmutableMap.copyOf(subQueryMap);
        }

        /**
         * Returns the decorrelated plan and the rewritten correlation condition for
         * {@code subQuery}, or {@code null} if that sub-query was not decorrelated.
         */
        public Pair<RelNode, RexNode> getSubQueryEquivalent(RexSubQuery subQuery) {
            return subQueryMap.get(subQuery);
        }
    }

    /** The decorrelated form of one relational expression. */
    private static class Frame {
        /** The rewritten relational expression. */
        final RelNode r;
        /**
         * Correlation condition pulled up from {@link #r}'s subtree, with input references
         * indexing {@link #r}'s output fields; {@code null} if there is none.
         */
        final RexNode c;
        /** Maps an output position of the original expression to its position in {@link #r}. */
        final ImmutableSortedMap<Integer, Integer> oldToNewOutputs;

        /**
         * @param oldRel the expression before rewriting; only used to validate the keys of
         *     {@code oldToNewOutputs} in assertions
         * @param newRel the rewritten expression, must not be null
         * @param corCondition pulled-up correlation condition, or {@code null}
         * @param oldToNewOutputs old-to-new output position mapping (copied)
         */
        Frame(
                RelNode oldRel,
                RelNode newRel,
                RexNode corCondition,
                Map<Integer, Integer> oldToNewOutputs) {
            this.r = Preconditions.checkNotNull(newRel);
            this.c = corCondition;
            this.oldToNewOutputs = ImmutableSortedMap.copyOf(oldToNewOutputs);
            assert allLessThan(
                    this.oldToNewOutputs.keySet(),
                    oldRel.getRowType().getFieldCount(),
                    Litmus.THROW);
            assert allLessThan(
                    this.oldToNewOutputs.values(), this.r.getRowType().getFieldCount(), Litmus.THROW);
        }

        /**
         * Returns the ascending field indices of {@link #r} referenced by {@link #c}, or an empty
         * mutable list when there is no condition.
         */
        List<Integer> getCorInputRefIndices() {
            // Typed as List<Integer>: the decompiled List<Object> local could not be returned here.
            if (c == null) {
                return new ArrayList<>();
            }
            return RelOptUtil.InputFinder.bits(c).toList();
        }

        /** Checks that every value in {@code integers} is below {@code limit}. */
        private static boolean allLessThan(Collection<Integer> integers, int limit, Litmus ret) {
            for (int value : integers) {
                if (value >= limit) {
                    return ret.fail("out of range; value: {}, limit: {}", value, limit);
                }
            }
            return ret.succeed();
        }
    }

    /** Correlation metadata collected by {@link CorelMapBuilder}. */
    private static class CorelMap {
        /** Node -> correlated field references found in its expressions (not copied). */
        private final Multimap<RelNode, CorRef> mapRefRelToCorRef;
        /** Correlation id -> filter whose variables set declares it. */
        private final SortedMap<CorrelationId, RelNode> mapCorToCorRel;
        /** Node inside a sub-query -> variables set of the innermost node holding that sub-query. */
        private final Map<RelNode, Set<CorrelationId>> mapSubQueryNodeToCorSet;

        private CorelMap(
                Multimap<RelNode, CorRef> mapRefRelToCorRef,
                SortedMap<CorrelationId, RelNode> mapCorToCorRel,
                Map<RelNode, Set<CorrelationId>> mapSubQueryNodeToCorSet) {
            this.mapRefRelToCorRef = mapRefRelToCorRef;
            this.mapCorToCorRel = mapCorToCorRel;
            this.mapSubQueryNodeToCorSet = ImmutableMap.copyOf(mapSubQueryNodeToCorSet);
        }

        @Override
        public String toString() {
            return new StringBuilder()
                    .append("mapRefRelToCorRef=")
                    .append(mapRefRelToCorRef)
                    .append("\nmapCorToCorRel=")
                    .append(mapCorToCorRel)
                    .append("\nmapSubQueryNodeToCorSet=")
                    .append(mapSubQueryNodeToCorSet)
                    .append("\n")
                    .toString();
        }

        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof CorelMap)) {
                return false;
            }
            final CorelMap that = (CorelMap) obj;
            return mapRefRelToCorRef.equals(that.mapRefRelToCorRef)
                    && mapCorToCorRel.equals(that.mapCorToCorRel)
                    && mapSubQueryNodeToCorSet.equals(that.mapSubQueryNodeToCorSet);
        }

        @Override
        public int hashCode() {
            return Objects.hash(mapRefRelToCorRef, mapCorToCorRel, mapSubQueryNodeToCorSet);
        }

        /** Creates a map; {@code mapSubQueryNodeToCorSet} is copied, the other two are not. */
        public static CorelMap of(
                SortedSetMultimap<RelNode, CorRef> mapRefRelToCorVar,
                SortedMap<CorrelationId, RelNode> mapCorToCorRel,
                Map<RelNode, Set<CorrelationId>> mapSubQueryNodeToCorSet) {
            return new CorelMap(mapRefRelToCorVar, mapCorToCorRel, mapSubQueryNodeToCorSet);
        }

        /** Whether any filter in the plan declares a correlation variable. */
        boolean hasCorrelation() {
            return !mapCorToCorRel.isEmpty();
        }
    }

    /**
     * A reference to field {@link #field} of correlation variable {@link #corr}.
     * {@link #uniqueKey} distinguishes otherwise identical references; {@code CorelMapBuilder}
     * assigns a fresh key per distinct {@link RexFieldAccess}.
     */
    private static class CorRef implements Comparable<CorRef> {
        final int uniqueKey;
        final CorrelationId corr;
        final int field;

        CorRef(CorrelationId corr, int field, int uniqueKey) {
            this.corr = corr;
            this.field = field;
            this.uniqueKey = uniqueKey;
        }

        @Override
        public String toString() {
            return corr.getName() + '.' + field;
        }

        @Override
        public int hashCode() {
            return Objects.hash(uniqueKey, corr, field);
        }

        /**
         * Equality over all three components. The correlation id is compared with
         * {@code equals} (previously {@code ==}) so that equality agrees with
         * {@link #hashCode} and {@link #compareTo}, which delegate to
         * {@code CorrelationId.hashCode} and {@code CorrelationId.compareTo}.
         */
        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (!(o instanceof CorRef)) {
                return false;
            }
            final CorRef other = (CorRef) o;
            return uniqueKey == other.uniqueKey
                    && field == other.field
                    && Objects.equals(corr, other.corr);
        }

        /** Orders by correlation id, then field, then unique key. */
        @Override
        public int compareTo(@Nonnull CorRef o) {
            int c = corr.compareTo(o.corr);
            if (c != 0) {
                return c;
            }
            c = Integer.compare(field, o.field);
            if (c != 0) {
                return c;
            }
            return Integer.compare(uniqueKey, o.uniqueKey);
        }
    }

    /**
     * Walks a plan, unwrapping {@link HepRelVertex} nodes, and records where correlation
     * variables are declared and referenced; {@link #build} returns the result as a
     * {@link CorelMap}.
     *
     * <p>While walking it also raises {@link #hasNestedCorScope} and
     * {@link #hasUnsupportedCorCondition}. Either flag makes {@code decorrelateQuery} give up.
     */
    private static class CorelMapBuilder extends RelShuttleImpl {
        /** Node limit handed to the CNF conversion when analyzing filter conditions. */
        private final int maxCnfNodeCount;

        /**
         * Set when a node inside a sub-query references a correlation variable that is not in
         * the variables set recorded for that node.
         */
        boolean hasNestedCorScope = false;

        /** Set when a correlation variable is used in a position this class cannot decorrelate. */
        boolean hasUnsupportedCorCondition = false;

        /** Whether an aggregate was visited since the last reset at the start of a sub-query. */
        boolean hasAggregateNode = false;

        /**
         * Whether a project containing an OVER call was visited since the last reset at the start
         * of a sub-query.
         */
        boolean hasOverNode = false;

        /** Correlation id -> filter whose variables set declares it. */
        final SortedMap<CorrelationId, RelNode> mapCorToCorRel = new TreeMap<>();

        /** Node -> correlated field references found in its expressions. */
        final SortedSetMultimap<RelNode, CorRef> mapRefRelToCorRef =
                Multimaps.newSortedSetMultimap(
                        new HashMap<RelNode, Collection<CorRef>>(),
                        new Supplier<TreeSet<CorRef>>() {
                            @Override
                            public TreeSet<CorRef> get() {
                                Bug.upgrade("use MultimapBuilder when we're on Guava-16");
                                return Sets.newTreeSet();
                            }
                        });

        /** Shares one {@link CorRef} between occurrences of the same field access. */
        final Map<RexFieldAccess, CorRef> mapFieldAccessToCorVar = new HashMap<>();

        /** Node inside a sub-query -> variables set of the innermost node holding that sub-query. */
        final Map<RelNode, Set<CorrelationId>> mapSubQueryNodeToCorSet = new HashMap<>();

        /** Source of {@link CorRef#uniqueKey} values. */
        int corrIdGenerator = 0;

        /** Nodes whose expressions contain a sub-query, innermost on top. */
        final Deque<RelNode> corNodeStack = new ArrayDeque<>();

        public CorelMapBuilder(int maxCnfNodeCount) {
            this.maxCnfNodeCount = maxCnfNodeCount;
        }

        /** Visits each root in {@code rels} and returns the collected correlation metadata. */
        CorelMap build(RelNode... rels) {
            for (RelNode rel : rels) {
                SubQueryDecorrelator.stripHep(rel).accept(this);
            }
            return CorelMap.of(mapRefRelToCorRef, mapCorToCorRel, mapSubQueryNodeToCorSet);
        }

        @Override
        protected RelNode visitChild(RelNode parent, int i, RelNode input) {
            return super.visitChild(parent, i, SubQueryDecorrelator.stripHep(input));
        }

        /** Correlation variables are not supported in either input of a correlate. */
        @Override
        public RelNode visit(LogicalCorrelate correlate) {
            checkCorConditionOfInput(correlate.getLeft());
            checkCorConditionOfInput(correlate.getRight());
            visitChild(correlate, 0, correlate.getLeft());
            visitChild(correlate, 1, correlate.getRight());
            return correlate;
        }

        @Override
        public RelNode visit(LogicalJoin join) {
            // Correlation variables are not supported on the null-generating side(s) of an
            // outer join.
            switch (join.getJoinType()) {
                case LEFT:
                    checkCorConditionOfInput(join.getRight());
                    break;
                case RIGHT:
                    checkCorConditionOfInput(join.getLeft());
                    break;
                case FULL:
                    checkCorConditionOfInput(join.getLeft());
                    checkCorConditionOfInput(join.getRight());
                    break;
                default:
                    break;
            }
            final boolean hasSubQuery = RexUtil.SubQueryFinder.find(join.getCondition()) != null;
            try {
                if (!corNodeStack.isEmpty()) {
                    mapSubQueryNodeToCorSet.put(join, corNodeStack.peek().getVariablesSet());
                }
                if (hasSubQuery) {
                    corNodeStack.push(join);
                }
                checkCorCondition(join);
                join.getCondition().accept(rexVisitor(join));
            } finally {
                if (hasSubQuery) {
                    corNodeStack.pop();
                }
            }
            visitChild(join, 0, join.getLeft());
            visitChild(join, 1, join.getRight());
            return join;
        }

        @Override
        public RelNode visit(LogicalFilter filter) {
            final boolean hasSubQuery =
                    RexUtil.SubQueryFinder.find(filter.getCondition()) != null;
            try {
                if (!corNodeStack.isEmpty()) {
                    mapSubQueryNodeToCorSet.put(filter, corNodeStack.peek().getVariablesSet());
                }
                if (hasSubQuery) {
                    corNodeStack.push(filter);
                }
                checkCorCondition(filter);
                filter.getCondition().accept(rexVisitor(filter));
                for (CorrelationId correlationId : filter.getVariablesSet()) {
                    mapCorToCorRel.put(correlationId, filter);
                }
            } finally {
                if (hasSubQuery) {
                    corNodeStack.pop();
                }
            }
            return super.visit(filter);
        }

        @Override
        public RelNode visit(LogicalProject project) {
            // Accumulate instead of overwriting, mirroring hasAggregateNode. Otherwise a plain
            // project between an OVER project and a correlated filter would hide the OVER.
            hasOverNode |= RexOver.containsOver(project.getProjects(), null);
            final boolean hasSubQuery =
                    RexUtil.SubQueryFinder.find(project.getProjects()) != null;
            try {
                if (!corNodeStack.isEmpty()) {
                    mapSubQueryNodeToCorSet.put(project, corNodeStack.peek().getVariablesSet());
                }
                if (hasSubQuery) {
                    corNodeStack.push(project);
                }
                checkCorCondition(project);
                for (RexNode node : project.getProjects()) {
                    node.accept(rexVisitor(project));
                }
            } finally {
                if (hasSubQuery) {
                    corNodeStack.pop();
                }
            }
            return super.visit(project);
        }

        @Override
        public RelNode visit(LogicalAggregate aggregate) {
            hasAggregateNode = true;
            return super.visit(aggregate);
        }

        @Override
        public RelNode visit(LogicalUnion union) {
            checkCorConditionOfSetOpInputs(union);
            return super.visit(union);
        }

        @Override
        public RelNode visit(LogicalMinus minus) {
            checkCorConditionOfSetOpInputs(minus);
            return super.visit(minus);
        }

        @Override
        public RelNode visit(LogicalIntersect intersect) {
            checkCorConditionOfSetOpInputs(intersect);
            return super.visit(intersect);
        }

        /**
         * For a filter inside a sub-query, flags conditions that are unsupported outright, and
         * non-equality correlation conditions when an aggregate or OVER has been seen.
         */
        private void checkCorCondition(LogicalFilter filter) {
            if (!mapSubQueryNodeToCorSet.containsKey(filter) || hasUnsupportedCorCondition) {
                return;
            }
            final List<RexNode> corConditions = new ArrayList<>();
            final List<RexNode> unsupportedCorConditions = new ArrayList<>();
            SubQueryDecorrelator.analyzeCorConditions(
                    mapSubQueryNodeToCorSet.get(filter),
                    filter.getCondition(),
                    filter.getCluster().getRexBuilder(),
                    maxCnfNodeCount,
                    corConditions,
                    new ArrayList<RexNode>(),
                    unsupportedCorConditions);
            if (!unsupportedCorConditions.isEmpty()) {
                hasUnsupportedCorCondition = true;
            } else if (!corConditions.isEmpty()) {
                boolean hasNonEquals = false;
                for (RexNode node : corConditions) {
                    if (node instanceof RexCall
                            && ((RexCall) node).getOperator() != SqlStdOperatorTable.EQUALS) {
                        hasNonEquals = true;
                        break;
                    }
                }
                // Non-equality correlation together with an aggregate or OVER is unsupported,
                // e.g. SELECT * FROM l WHERE b IN (SELECT MIN(e) FROM r WHERE l.c > r.f)
                hasUnsupportedCorCondition = hasNonEquals && (hasAggregateNode || hasOverNode);
            }
        }

        /** Correlation variables in join conditions are not supported. */
        private void checkCorCondition(LogicalJoin join) {
            if (!hasUnsupportedCorCondition) {
                join.getCondition()
                        .accept(
                                new RexVisitorImpl<Void>(true) {
                                    @Override
                                    public Void visitCorrelVariable(
                                            RexCorrelVariable correlVariable) {
                                        hasUnsupportedCorCondition = true;
                                        return super.visitCorrelVariable(correlVariable);
                                    }
                                });
            }
        }

        /** Correlation variables in projections are not supported. */
        private void checkCorCondition(LogicalProject project) {
            if (!hasUnsupportedCorCondition) {
                for (RexNode node : project.getProjects()) {
                    node.accept(
                            new RexVisitorImpl<Void>(true) {
                                @Override
                                public Void visitCorrelVariable(RexCorrelVariable correlVariable) {
                                    hasUnsupportedCorCondition = true;
                                    return super.visitCorrelVariable(correlVariable);
                                }
                            });
                }
            }
        }

        /**
         * Flags any correlation variable found in the filter, project or join expressions of the
         * subtree rooted at {@code input}.
         */
        private void checkCorConditionOfInput(RelNode input) {
            final RelShuttleImpl shuttle =
                    new RelShuttleImpl() {
                        final RexVisitor<Void> visitor =
                                new RexVisitorImpl<Void>(true) {
                                    @Override
                                    public Void visitCorrelVariable(
                                            RexCorrelVariable correlVariable) {
                                        hasUnsupportedCorCondition = true;
                                        return super.visitCorrelVariable(correlVariable);
                                    }
                                };

                        @Override
                        public RelNode visit(LogicalFilter filter) {
                            filter.getCondition().accept(visitor);
                            return super.visit(filter);
                        }

                        @Override
                        public RelNode visit(LogicalProject project) {
                            for (RexNode rex : project.getProjects()) {
                                rex.accept(visitor);
                            }
                            return super.visit(project);
                        }

                        @Override
                        public RelNode visit(LogicalJoin join) {
                            join.getCondition().accept(visitor);
                            return super.visit(join);
                        }
                    };
            input.accept(shuttle);
        }

        /** Correlation variables are not supported in the inputs of a set operation. */
        private void checkCorConditionOfSetOpInputs(SetOp setOp) {
            for (RelNode child : setOp.getInputs()) {
                checkCorConditionOfInput(child);
            }
        }

        /**
         * Creates a visitor for the expressions of {@code rel}. It descends into sub-query plans
         * with this builder, and records every correlated field access under {@code rel}.
         */
        private RexVisitorImpl<Void> rexVisitor(final RelNode rel) {
            return new RexVisitorImpl<Void>(true) {

                @Override
                public Void visitSubQuery(RexSubQuery subQuery) {
                    // Flags are tracked per sub-query.
                    hasAggregateNode = false;
                    hasOverNode = false;
                    // The builder itself must walk the sub-query plan; the decompiled
                    // `accept(this)` passed this RexVisitor, which is not a RelShuttle.
                    subQuery.rel.accept(CorelMapBuilder.this);
                    return super.visitSubQuery(subQuery);
                }

                @Override
                public Void visitFieldAccess(RexFieldAccess fieldAccess) {
                    final RexNode ref = fieldAccess.getReferenceExpr();
                    if (ref instanceof RexCorrelVariable) {
                        final RexCorrelVariable var = (RexCorrelVariable) ref;
                        final boolean insideSubQuery = mapSubQueryNodeToCorSet.containsKey(rel);
                        if (!hasUnsupportedCorCondition) {
                            // Correlation outside any sub-query scope is not supported.
                            hasUnsupportedCorCondition = !insideSubQuery;
                        }
                        if (!hasNestedCorScope && insideSubQuery) {
                            hasNestedCorScope =
                                    !mapSubQueryNodeToCorSet.get(rel).contains(var.id);
                        }
                        if (mapFieldAccessToCorVar.containsKey(fieldAccess)) {
                            mapRefRelToCorRef.put(rel, mapFieldAccessToCorVar.get(fieldAccess));
                        } else {
                            final CorRef correlation =
                                    new CorRef(
                                            var.id,
                                            fieldAccess.getField().getIndex(),
                                            corrIdGenerator++);
                            mapFieldAccessToCorVar.put(fieldAccess, correlation);
                            mapRefRelToCorRef.put(rel, correlation);
                        }
                    }
                    return super.visitFieldAccess(fieldAccess);
                }
            };
        }
    }

    public static class SubQueryRelDecorrelator
    implements ReflectiveVisitor {
        /** Correlation metadata collected by {@code CorelMapBuilder}. */
        private final CorelMap cm;
        private final RelBuilder relBuilder;
        private final RexBuilder rexBuilder;
        /** Reflectively routes {@link #getInvoke} to a {@code decorrelateRel} overload. */
        private final ReflectUtil.MethodDispatcher<Frame> dispatcher =
                ReflectUtil.createMethodDispatcher(
                        Frame.class, this, "decorrelateRel", RelNode.class);
        /** Node limit handed to the CNF conversion when analyzing filter conditions. */
        private final int maxCnfNodeCount;

        SubQueryRelDecorrelator(
                CorelMap cm, RelBuilder relBuilder, RexBuilder rexBuilder, int maxCnfNodeCount) {
            this.maxCnfNodeCount = maxCnfNodeCount;
            this.rexBuilder = rexBuilder;
            this.relBuilder = relBuilder;
            this.cm = cm;
        }

        /**
         * Decorrelates {@code r} through the matching {@code decorrelateRel} overload.
         *
         * @return the decorrelated frame, or {@code null} when {@code r} cannot be rewritten
         */
        Frame getInvoke(RelNode r) {
            final Frame decorrelated = dispatcher.invoke(r);
            return decorrelated;
        }

        /**
         * Rewrites a project on top of its decorrelated input. The original expressions are kept
         * at their original positions. Any input columns referenced by the input's correlation
         * condition that are not already projected as plain references are appended after them.
         *
         * @return the new frame, or {@code null} if the input could not be rewritten
         */
        public Frame decorrelateRel(LogicalProject rel) {
            final Frame inputFrame = getInvoke(rel.getInput());
            if (inputFrame == null) {
                // The input was not rewritten, so this project is not rewritten either.
                return null;
            }
            final RelNode newInput = inputFrame.r;
            final RelDataType newInputType = newInput.getRowType();
            final List<RexNode> oldProjects = rel.getProjects();
            final List<RelDataTypeField> outputFields = rel.getRowType().getFieldList();

            // CorelMapBuilder marks correlation in projections as unsupported.
            assert !cm.mapRefRelToCorRef.containsKey(rel);

            final List<Pair<RexNode, String>> projects = new ArrayList<>();
            final Map<Integer, Integer> mapInputToOutput = new HashMap<>();
            final Map<Integer, Integer> mapOldToNewOutputs = new HashMap<>();

            // Original expressions keep their positions.
            int pos = 0;
            for (RexNode oldProject : oldProjects) {
                final RexNode project =
                        SubQueryDecorrelator.adjustInputRefs(
                                oldProject, inputFrame.oldToNewOutputs, newInputType);
                projects.add(Pair.of(project, outputFields.get(pos).getName()));
                mapOldToNewOutputs.put(pos, pos);
                if (project instanceof RexInputRef) {
                    mapInputToOutput.put(((RexInputRef) project).getIndex(), pos);
                }
                pos++;
            }

            // Pass along the input columns the correlation condition still needs.
            if (inputFrame.c != null) {
                for (int inputIndex : RelOptUtil.InputFinder.bits(inputFrame.c)) {
                    if (mapInputToOutput.containsKey(inputIndex)) {
                        continue;
                    }
                    projects.add(
                            Pair.of(
                                    RexInputRef.of(inputIndex, newInputType),
                                    newInputType.getFieldNames().get(inputIndex)));
                    mapInputToOutput.put(inputIndex, pos);
                    pos++;
                }
            }

            final RelNode newProject = RelOptUtil.createProject(newInput, projects, false);
            RexNode newCorCondition = null;
            if (inputFrame.c != null) {
                newCorCondition =
                        SubQueryDecorrelator.adjustInputRefs(
                                inputFrame.c, mapInputToOutput, newProject.getRowType());
            }
            return new Frame(rel, newProject, newCorCondition, mapOldToNewOutputs);
        }

        /**
         * Rewrites a filter on top of its decorrelated input. Conjuncts that reference
         * correlation variables are removed from the filter and pulled up into the frame's
         * condition, together with the input's own condition. The remaining conjuncts stay in a
         * new {@link LogicalFilter} that keeps the original variables set.
         *
         * @return the new frame, or {@code null} if the input could not be rewritten
         */
        public Frame decorrelateRel(LogicalFilter rel) {
            final Frame frame = getInvoke(rel.getInput());
            if (frame == null) {
                // The input was not rewritten, so this filter is not rewritten either.
                return null;
            }

            // Typed lists replace the decompiled raw ArrayLists and the unchecked (Set) cast.
            final List<RexNode> corConditions = new ArrayList<>();
            final List<RexNode> nonCorConditions = new ArrayList<>();
            final List<RexNode> unsupportedCorConditions = new ArrayList<>();
            SubQueryDecorrelator.analyzeCorConditions(
                    cm.mapSubQueryNodeToCorSet.get(rel),
                    rel.getCondition(),
                    rexBuilder,
                    maxCnfNodeCount,
                    corConditions,
                    nonCorConditions,
                    unsupportedCorConditions);
            // CorelMapBuilder already rejected plans with unsupported correlation conditions.
            assert unsupportedCorConditions.isEmpty();

            // NOTE(review): conjuncts keep the old input's field positions. This assumes
            // frame.oldToNewOutputs is an identity mapping for the referenced fields. That holds
            // for Project/Filter inputs; confirm for other inputs.
            final RexNode remainingCondition =
                    RexUtil.composeConjunction(rexBuilder, nonCorConditions, false);
            // LogicalFilter.create is used because it accepts the variables set.
            final LogicalFilter newFilter =
                    LogicalFilter.create(
                            frame.r, remainingCondition, ImmutableSet.copyOf(rel.getVariablesSet()));

            if (frame.c != null) {
                corConditions.add(frame.c);
            }
            // null when there is no correlation condition at all.
            final RexNode corCondition = RexUtil.composeConjunction(rexBuilder, corConditions, true);
            // A filter does not reorder its input's fields.
            return new Frame(rel, newFilter, corCondition, frame.oldToNewOutputs);
        }

        /**
         * Decorrelates a {@link LogicalAggregate}.
         *
         * <p>The rewritten input is first re-projected so that its columns appear in this order:
         * <ol>
         *   <li>the original group keys, except those the input projects as literals;</li>
         *   <li>input columns referenced by the input frame's correlation condition that are not
         *       already among them;</li>
         *   <li>all remaining input columns.</li>
         * </ol>
         * The new aggregate groups on the first two parts, so the correlation columns remain
         * visible above the aggregate. Group keys that were projected literals are removed from
         * the grouping and re-inserted by a project on top of the new aggregate.
         *
         * @param rel the aggregate to rewrite; asserted not to reference correlation variables
         * @return the rewritten frame, or {@code null} if the input could not be decorrelated
         */
        public Frame decorrelateRel(LogicalAggregate rel) {
            assert (!this.cm.mapRefRelToCorRef.containsKey(rel));
            RelNode oldInput = rel.getInput();
            Frame frame = this.getInvoke(oldInput);
            if (frame == null) {
                return null;
            }
            RelNode newInput = frame.r;
            // position in newInput -> position in the re-ordering project built below
            HashMap<Integer, Integer> mapNewInputToProjOutputs = new HashMap<Integer, Integer>();
            int oldGroupKeyCount = rel.getGroupSet().cardinality();
            ArrayList<Pair<RexNode, String>> projects = new ArrayList<Pair<RexNode, String>>();
            List<RelDataTypeField> newInputOutput = newInput.getRowType().getFieldList();
            // group keys the input projects as literals: key index -> literal (ascending order
            // matters for the positional re-insertion below)
            TreeMap<Integer, RexLiteral> omittedConstants = new TreeMap<Integer, RexLiteral>();
            int newPos = 0;
            // (1) group keys. NOTE(review): this treats input fields 0..oldGroupKeyCount-1 as the
            // group keys, i.e. presumes the group set is a contiguous prefix — confirm.
            for (int i = 0; i < oldGroupKeyCount; ++i) {
                RexLiteral constant = SubQueryRelDecorrelator.projectedLiteral(newInput, i);
                if (constant != null) {
                    omittedConstants.put(i, constant);
                    continue;
                }
                int newInputPos = frame.oldToNewOutputs.get(i);
                projects.add(newPos, RexInputRef.of2(newInputPos, newInputOutput));
                mapNewInputToProjOutputs.put(newInputPos, newPos);
                ++newPos;
            }
            // (2) columns needed by the correlation condition also become group keys
            if (frame.c != null) {
                for (Integer index : frame.getCorInputRefIndices()) {
                    if (mapNewInputToProjOutputs.containsKey(index)) continue;
                    projects.add(newPos, RexInputRef.of2(index, newInputOutput));
                    mapNewInputToProjOutputs.put(index, newPos);
                    ++newPos;
                }
            }
            int newGroupKeyCount = newPos;
            // (3) every other input column, so aggregate arguments can still be referenced
            for (int i = 0; i < newInputOutput.size(); ++i) {
                if (mapNewInputToProjOutputs.containsKey(i)) continue;
                projects.add(newPos, RexInputRef.of2(i, newInputOutput));
                mapNewInputToProjOutputs.put(i, newPos);
                ++newPos;
            }
            assert (newPos == newInputOutput.size());
            RelNode newProject = RelOptUtil.createProject(newInput, projects, false);
            RexNode newCondition = frame.c != null ? SubQueryDecorrelator.adjustInputRefs(frame.c, mapNewInputToProjOutputs, newProject.getRowType()) : null;
            // combinedMap: old input position -> position in newProject (used for agg args)
            // oldToNewOutputs: same, restricted to the original group keys (the frame's result)
            HashMap<Integer, Integer> combinedMap = new HashMap<Integer, Integer>();
            HashMap<Integer, Integer> oldToNewOutputs = new HashMap<Integer, Integer>();
            List<Integer> originalGrouping = rel.getGroupSet().toList();
            for (Integer oldInputPos : frame.oldToNewOutputs.keySet()) {
                Integer newIndex = mapNewInputToProjOutputs.get(frame.oldToNewOutputs.get(oldInputPos));
                combinedMap.put(oldInputPos, newIndex);
                if (!originalGrouping.contains(oldInputPos)) continue;
                oldToNewOutputs.put(oldInputPos, newIndex);
            }
            ImmutableBitSet newGroupSet = ImmutableBitSet.range(newGroupKeyCount);
            ArrayList<AggregateCall> newAggCalls = new ArrayList<AggregateCall>();
            List<AggregateCall> oldAggCalls = rel.getAggCallList();
            for (AggregateCall oldAggCall : oldAggCalls) {
                List<Integer> oldAggArgs = oldAggCall.getArgList();
                ArrayList<Integer> aggArgs = new ArrayList<Integer>();
                for (int oldPos : oldAggArgs) {
                    aggArgs.add(combinedMap.get(oldPos));
                }
                // a negative filterArg means "no filter" and is kept as-is
                int filterArg = oldAggCall.filterArg < 0 ? oldAggCall.filterArg : combinedMap.get(oldAggCall.filterArg);
                newAggCalls.add(oldAggCall.adaptTo(newProject, aggArgs, filterArg, oldGroupKeyCount, newGroupKeyCount));
            }
            this.relBuilder.push(LogicalAggregate.create(newProject, false, newGroupSet, null, newAggCalls));
            if (!omittedConstants.isEmpty()) {
                // Re-insert the dropped literal group keys, in ascending key order.
                // NOTE(review): entry.getKey() is an old group-key index but is looked up in a map
                // keyed by newInput positions; this is only consistent when frame.oldToNewOutputs
                // is the identity on those keys — confirm.
                ArrayList<RexNode> postProjects = new ArrayList<RexNode>(this.relBuilder.fields());
                for (Map.Entry<Integer, RexLiteral> entry : omittedConstants.entrySet()) {
                    postProjects.add(mapNewInputToProjOutputs.get(entry.getKey()), entry.getValue());
                }
                this.relBuilder.project(postProjects);
            }
            // aggregate results follow the new group keys and the re-inserted literals
            for (int i = 0; i < oldAggCalls.size(); ++i) {
                oldToNewOutputs.put(oldGroupKeyCount + i, newGroupKeyCount + omittedConstants.size() + i);
            }
            return new Frame(rel, this.relBuilder.build(), newCondition, oldToNewOutputs);
        }

        /**
         * Decorrelates a {@link LogicalJoin} by joining the two rewritten inputs.
         *
         * <p>The join condition is re-indexed onto the new inputs. The resulting correlation
         * condition is the conjunction of the left frame's condition and the right frame's
         * condition. The right condition's input references are first shifted past the new
         * left fields. Only assertions check that an outer join carries no correlation condition
         * on a side that may be null-extended.
         *
         * @param rel the join to rewrite
         * @return the rewritten frame, or {@code null} if either input could not be decorrelated
         */
        public Frame decorrelateRel(LogicalJoin rel) {
            RelNode leftIn = rel.getInput(0);
            RelNode rightIn = rel.getInput(1);
            Frame left = this.getInvoke(leftIn);
            Frame right = this.getInvoke(rightIn);
            if (left == null || right == null) {
                return null;
            }
            // LEFT/FULL null-extend the right side, RIGHT/FULL null-extend the left side.
            assert (!rel.getJoinType().generatesNullsOnRight() || right.c == null);
            assert (!rel.getJoinType().generatesNullsOnLeft() || left.c == null);

            int oldLeftWidth = leftIn.getRowType().getFieldCount();
            int newLeftWidth = left.r.getRowType().getFieldCount();
            int oldRightWidth = rightIn.getRowType().getFieldCount();
            assert (rel.getRowType().getFieldCount() == oldLeftWidth + oldRightWidth);

            RexNode rewrittenCondition = this.adjustJoinCondition(rel.getCondition(), oldLeftWidth, newLeftWidth, left.oldToNewOutputs, right.oldToNewOutputs);
            LogicalJoin joined = LogicalJoin.create(left.r, right.r, rel.getHints(), rewrittenCondition, rel.getVariablesSet(), rel.getJoinType());

            // Left outputs map as the left frame says; right outputs move past both left widths.
            HashMap<Integer, Integer> outputMapping = new HashMap<Integer, Integer>(left.oldToNewOutputs);
            for (int r = 0; r < oldRightWidth; ++r) {
                outputMapping.put(oldLeftWidth + r, right.oldToNewOutputs.get(r) + newLeftWidth);
            }

            ArrayList<RexNode> conjuncts = new ArrayList<RexNode>();
            if (left.c != null) {
                conjuncts.add(left.c);
            }
            if (right.c != null) {
                HashMap<Integer, Integer> shiftRight = new HashMap<Integer, Integer>();
                for (int idx : right.getCorInputRefIndices()) {
                    shiftRight.put(idx, idx + newLeftWidth);
                }
                conjuncts.add(SubQueryDecorrelator.adjustInputRefs(right.c, shiftRight, joined.getRowType()));
            }
            RexNode combined = RexUtil.composeConjunction(this.rexBuilder, conjuncts, true);
            return new Frame(rel, joined, combined, outputMapping);
        }

        /**
         * Rewrites the input references of a join condition from the old join's field layout to
         * the new join's layout.
         *
         * <p>References below {@code oldLeftFieldCount} are looked up in the left mapping. All
         * other references are first translated into the right input's own field space, then
         * looked up in the right mapping and shifted past the new left fields.
         *
         * @param joinCondition the condition over the old join's row type
         * @param oldLeftFieldCount number of fields of the old left input
         * @param newLeftFieldCount number of fields of the rewritten left input
         * @param leftOldToNewOutputs old -> new field positions of the left input
         * @param rightOldToNewOutputs old -> new field positions of the right input
         * @return the condition over the new join's row type
         */
        private RexNode adjustJoinCondition(RexNode joinCondition, final int oldLeftFieldCount, final int newLeftFieldCount, final Map<Integer, Integer> leftOldToNewOutputs, final Map<Integer, Integer> rightOldToNewOutputs) {
            return joinCondition.accept(new RexShuttle(){

                @Override
                public RexNode visitInputRef(RexInputRef inputRef) {
                    int newIndex;
                    int oldIndex = inputRef.getIndex();
                    if (oldIndex < oldLeftFieldCount) {
                        assert (leftOldToNewOutputs.containsKey(oldIndex));
                        newIndex = leftOldToNewOutputs.get(oldIndex);
                    } else {
                        // The shift must be a real statement, not part of the assert: assertions
                        // are disabled by default, which would silently skip the adjustment.
                        int oldRightIndex = oldIndex - oldLeftFieldCount;
                        assert (rightOldToNewOutputs.containsKey(oldRightIndex));
                        newIndex = rightOldToNewOutputs.get(oldRightIndex) + newLeftFieldCount;
                    }
                    return new RexInputRef(newIndex, inputRef.getType());
                }
            });
        }

        /**
         * Decorrelates a {@link Sort} by re-creating it as a {@link LogicalSort} over the
         * rewritten input. The collation is remapped through the input frame's output mapping.
         * The input's correlation condition and output mapping are passed through unchanged.
         *
         * <p>NOTE(review): offset/fetch are applied to the rewritten input before the carried
         * correlation condition is applied further up — confirm a LIMIT here is not expected to
         * be per correlated value.
         *
         * @param rel the sort to rewrite; asserted not to reference correlation variables
         * @return the rewritten frame, or {@code null} if the input could not be decorrelated
         */
        public Frame decorrelateRel(Sort rel) {
            assert (!this.cm.mapRefRelToCorRef.containsKey(rel));
            RelNode source = rel.getInput();
            Frame inputFrame = this.getInvoke(source);
            if (inputFrame == null) {
                return null;
            }
            RelNode rewrittenSource = inputFrame.r;
            int sourceWidth = source.getRowType().getFieldCount();
            int targetWidth = rewrittenSource.getRowType().getFieldCount();
            Mappings.TargetMapping positions = Mappings.target(inputFrame.oldToNewOutputs, sourceWidth, targetWidth);
            RelCollation remapped = RexUtil.apply(positions, rel.getCollation());
            LogicalSort sorted = LogicalSort.create(rewrittenSource, remapped, rel.offset, rel.fetch);
            return new Frame(rel, sorted, inputFrame.c, inputFrame.oldToNewOutputs);
        }

        /**
         * {@link Values} nodes are never rewritten here. The method always returns {@code null},
         * which the other {@code decorrelateRel} overloads in this class treat as "could not be
         * decorrelated" and propagate.
         */
        public Frame decorrelateRel(Values rel) {
            final Frame notRewritten = null;
            return notRewritten;
        }

        /**
         * A nested {@link LogicalCorrelate} gets no special treatment: it is routed to the
         * generic {@link #decorrelateRel(RelNode)} fallback.
         */
        public Frame decorrelateRel(LogicalCorrelate rel) {
            final RelNode asGeneric = rel;
            return this.decorrelateRel(asGeneric);
        }

        /**
         * Generic fallback for nodes without a dedicated overload. It copies the node on top of
         * its rewritten inputs and maps the output fields one-to-one.
         *
         * <p>Returns {@code null} if any input cannot be rewritten or if any input frame carries
         * a correlation condition, because this path has no way to propagate one.
         */
        public Frame decorrelateRel(RelNode rel) {
            RelNode result = rel.copy(rel.getTraitSet(), rel.getInputs());
            List<RelNode> originalInputs = rel.getInputs();
            if (!originalInputs.isEmpty()) {
                ArrayList<RelNode> rewrittenInputs = new ArrayList<RelNode>();
                for (int idx = 0; idx < originalInputs.size(); ++idx) {
                    Frame inputFrame = this.getInvoke(originalInputs.get(idx));
                    boolean usable = inputFrame != null && inputFrame.c == null;
                    if (!usable) {
                        return null;
                    }
                    rewrittenInputs.add(inputFrame.r);
                    result.replaceInput(idx, inputFrame.r);
                }
                // Only build a fresh copy when some input actually changed identity.
                if (!Util.equalShallow(originalInputs, rewrittenInputs)) {
                    result = rel.copy(rel.getTraitSet(), rewrittenInputs);
                }
            }
            int width = rel.getRowType().getFieldCount();
            return new Frame(rel, result, null, SubQueryRelDecorrelator.identityMap(width));
        }

        /**
         * Builds an immutable map sending every position in {@code [0, count)} to itself.
         *
         * @param count number of positions; a non-positive value yields an empty map
         * @return the identity mapping
         */
        private static Map<Integer, Integer> identityMap(int count) {
            ImmutableMap.Builder<Integer, Integer> identity = ImmutableMap.builder();
            int pos = 0;
            while (pos < count) {
                identity.put(pos, pos);
                pos++;
            }
            return identity.build();
        }

        /**
         * Returns the literal that {@code rel} projects at output position {@code i}.
         *
         * @return that literal, or {@code null} if {@code rel} is not a {@link Project} or its
         *     {@code i}-th expression is not a {@link RexLiteral}
         */
        private static RexLiteral projectedLiteral(RelNode rel, int i) {
            if (!(rel instanceof Project)) {
                return null;
            }
            RexNode expr = ((Project) rel).getProjects().get(i);
            return expr instanceof RexLiteral ? (RexLiteral) expr : null;
        }
    }

    /**
     * Rewrites an expression into the index space where the left row's fields come first,
     * followed by the right row's fields.
     *
     * <p>A field access on a correlation variable becomes an input reference at the accessed
     * field's index, which lies in the left part. A plain input reference, which must point into
     * the right row type, is shifted past the left row's fields.
     */
    private static class DecorrelateRexShuttle
    extends RexShuttle {
        private final RelDataType leftRowType;
        private final RelDataType rightRowType;
        private final Set<CorrelationId> variableSet;

        private DecorrelateRexShuttle(RelDataType leftType, RelDataType rightType, Set<CorrelationId> allowedCorrelations) {
            this.leftRowType = leftType;
            this.rightRowType = rightType;
            this.variableSet = allowedCorrelations;
        }

        @Override
        public RexNode visitFieldAccess(RexFieldAccess fieldAccess) {
            RexNode target = fieldAccess.getReferenceExpr();
            if (!(target instanceof RexCorrelVariable)) {
                return super.visitFieldAccess(fieldAccess);
            }
            assert (this.variableSet.contains(((RexCorrelVariable) target).id));
            RelDataTypeField accessed = fieldAccess.getField();
            return new RexInputRef(accessed.getIndex(), accessed.getType());
        }

        @Override
        public RexNode visitInputRef(RexInputRef inputRef) {
            int rightIndex = inputRef.getIndex();
            assert (rightIndex < this.rightRowType.getFieldCount());
            return new RexInputRef(this.leftRowType.getFieldCount() + rightIndex, inputRef.getType());
        }
    }
}

