/*
 * Decompiled with CFR 0.152.
 */
package org.apache.drill.exec.planner.physical.visitor;

import java.util.ArrayList;
import java.util.List;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode;
import org.apache.drill.exec.planner.physical.JoinPrel;
import org.apache.drill.exec.planner.physical.LateralJoinPrel;
import org.apache.drill.exec.planner.physical.Prel;
import org.apache.drill.exec.planner.physical.UnnestPrel;
import org.apache.drill.exec.planner.physical.visitor.BasePrelVisitor;
import org.apache.drill.shaded.guava.com.google.common.base.Preconditions;
import org.apache.drill.shaded.guava.com.google.common.collect.Lists;

/**
 * Physical-plan visitor that rebuilds every {@link Prel} on top of its (recursively visited)
 * children, adjusting the inputs of join, lateral-join and unnest operators along the way.
 *
 * <ul>
 *   <li>{@link JoinPrel}: each input is passed through {@code JoinPrel#getJoinInput}; the left
 *       input gets offset 0 and the right input gets the left input's field count as offset
 *       (presumably to produce non-conflicting field names; see that method).</li>
 *   <li>{@link LateralJoinPrel}: each input is passed through {@code LateralJoinPrel#getLateralInput},
 *       and the lateral join is registered while its right input (child 1) is visited so that an
 *       unnest below it can locate it.</li>
 *   <li>{@link UnnestPrel}: rebuilt with a single output column named after the first required
 *       column of the enclosing lateral join's left input, reading that column through a
 *       correlated field reference.</li>
 * </ul>
 *
 * <p>Instances carry per-traversal mutable state, so a single instance must not be shared between
 * concurrent traversals; {@link #adjustSchema(Prel)} creates a fresh visitor for every call.
 */
public class AdjustOperatorsSchemaVisitor extends BasePrelVisitor<Prel, Void, RuntimeException> {

    /**
     * Innermost lateral join whose right input is currently being visited, or {@code null} when
     * the traversal is not inside any lateral join's right input.
     */
    private Prel registeredPrel = null;

    /**
     * Adjusts the schema of every operator in the given physical plan.
     *
     * <p>A new visitor is created per call: the visitor keeps mutable traversal state, and a shared
     * static instance would let concurrent calls (e.g. plans built on different threads) overwrite
     * each other's registered lateral join.
     *
     * @param prel root of the physical plan to rewrite
     * @return the rewritten plan root
     */
    public static Prel adjustSchema(Prel prel) {
        return prel.accept(new AdjustOperatorsSchemaVisitor(), null);
    }

    /** Records {@code prel} as the lateral join enclosing the subtree about to be visited. */
    private void register(Prel prel) {
        this.registeredPrel = prel;
    }

    /** Returns the currently registered lateral join, or {@code null} if there is none. */
    private Prel getRegisteredPrel() {
        return this.registeredPrel;
    }

    /** Default handling: visit all children and copy {@code prel} on top of the results. */
    @Override
    public Prel visitPrel(Prel prel, Void value) throws RuntimeException {
        return this.preparePrel(prel, this.getChildren(prel));
    }

    /** Clears any registered lateral join. */
    public void unRegister() {
        this.registeredPrel = null;
    }

    /**
     * Visits every child of {@code prel} in order and collects the rewritten children.
     *
     * <p>While the child at index {@code registerForChild} is visited, {@code prel} is registered
     * as the enclosing lateral join. Afterwards the previous registration is restored (not
     * cleared), so nested lateral joins do not erase the registration of an outer lateral join
     * whose right input is still being visited.
     *
     * @param prel the parent operator
     * @param registerForChild index of the child to visit with {@code prel} registered, or
     *     {@code -1} to register nothing
     * @return the rewritten children, in their original order
     */
    private List<RelNode> getChildren(Prel prel, int registerForChild) {
        List<RelNode> children = new ArrayList<>();
        int ch = 0;
        for (Prel child : prel) {
            if (ch == registerForChild) {
                Prel previous = this.getRegisteredPrel();
                this.register(prel);
                try {
                    children.add(child.accept(this, null));
                } finally {
                    // Restore rather than clear: an outer lateral join may still be in scope.
                    this.register(previous);
                }
            } else {
                children.add(child.accept(this, null));
            }
            ++ch;
        }
        return children;
    }

    /** Visits all children of {@code prel} without registering it. */
    private List<RelNode> getChildren(Prel prel) {
        return this.getChildren(prel, -1);
    }

    /** Copies {@code prel} with its trait set unchanged and the given inputs. */
    private Prel preparePrel(Prel prel, List<RelNode> renamedNodes) {
        return (Prel) prel.copy(prel.getTraitSet(), renamedNodes);
    }

    /**
     * Rebuilds a join: the left input goes through {@code getJoinInput} with offset 0, the right
     * input with the (rewritten) left input's field count as offset.
     */
    @Override
    public Prel visitJoin(JoinPrel prel, Void value) throws RuntimeException {
        List<RelNode> children = this.getChildren(prel);
        int leftCount = children.get(0).getRowType().getFieldCount();
        List<RelNode> reNamedChildren = new ArrayList<>();
        reNamedChildren.add(prel.getJoinInput(0, children.get(0)));
        reNamedChildren.add(prel.getJoinInput(leftCount, children.get(1)));
        return this.preparePrel(prel, reNamedChildren);
    }

    /**
     * Rebuilds a lateral join. The lateral join is registered while its right input (child 1) is
     * visited, so an {@link UnnestPrel} there can reference it; each rewritten child is then
     * passed through {@code getLateralInput}.
     */
    @Override
    public Prel visitLateral(LateralJoinPrel prel, Void value) throws RuntimeException {
        List<RelNode> children = this.getChildren(prel, 1);
        List<RelNode> reNamedChildren = new ArrayList<>();
        for (int i = 0; i < children.size(); ++i) {
            reNamedChildren.add(prel.getLateralInput(i, children.get(i)));
        }
        return this.preparePrel(prel, reNamedChildren);
    }

    /**
     * Rebuilds an unnest so that it reads the correlated column of its enclosing lateral join.
     *
     * <p>The correlated column is the first set bit of the lateral join's required columns, looked
     * up by name in the lateral join's left input. The new unnest has one output column with that
     * name and the original unnest's single column type.
     *
     * @throws IllegalArgumentException if no lateral join is registered, if the unnest does not
     *     produce exactly one column, or if the lateral join has no required column
     */
    @Override
    public Prel visitUnnest(UnnestPrel prel, Void value) throws RuntimeException {
        // instanceof is false for null, so this also covers "nothing registered".
        Preconditions.checkArgument(this.registeredPrel instanceof LateralJoinPrel,
            "Unnest must be within the right input of a lateral join; registered operator: %s",
            this.registeredPrel);
        Preconditions.checkArgument(prel.getRowType().getFieldCount() == 1,
            "Unnest is expected to produce exactly one column but produces %s",
            prel.getRowType().getFieldCount());

        RexBuilder builder = prel.getCluster().getRexBuilder();
        LateralJoinPrel lateralJoinPrel = (LateralJoinPrel) this.getRegisteredPrel();
        RelDataType leftRowType = lateralJoinPrel.getLeft().getRowType();

        int correlationIndex = lateralJoinPrel.getRequiredColumns().nextSetBit(0);
        // nextSetBit returns -1 when no column is required; fail with a clear message instead of
        // an index error from get(-1).
        Preconditions.checkArgument(correlationIndex >= 0,
            "Lateral join has no required column to correlate the unnest with");
        String correlationColumnName = leftRowType.getFieldNames().get(correlationIndex);

        // $cor.<column>: correlated reference into the lateral join's left row.
        RexNode corrRef = builder.makeCorrel(leftRowType, lateralJoinPrel.getCorrelationId());
        RexNode fieldAccess = builder.makeFieldAccess(corrRef, correlationColumnName, false);

        // Exactly one field is guaranteed by the precondition above.
        RelDataTypeField unnestField = prel.getRowType().getFieldList().get(0);
        List<RelDataType> fieldTypes = Lists.newArrayList(unnestField.getType());
        List<String> fieldNames = Lists.newArrayList(correlationColumnName);
        RelDataType rowType = prel.getCluster().getTypeFactory().createStructType(fieldTypes, fieldNames);

        return new UnnestPrel(prel.getCluster(), prel.getTraitSet(), rowType, fieldAccess);
    }
}

