/*
 * Decompiled with CFR 0.152.
 */
package org.apache.drill.exec.planner.sql;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlBasicCall;
import org.apache.calcite.sql.SqlCall;
import org.apache.calcite.sql.SqlFunction;
import org.apache.calcite.sql.SqlIntervalQualifier;
import org.apache.calcite.sql.SqlLiteral;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlNumericLiteral;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.fun.SqlRandFunction;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql2rel.SqlRexContext;
import org.apache.calcite.sql2rel.SqlRexConvertlet;
import org.apache.calcite.sql2rel.SqlRexConvertletTable;
import org.apache.calcite.sql2rel.StandardConvertletTable;
import org.apache.drill.exec.expr.fn.DrillFuncHolder;
import org.apache.drill.exec.planner.sql.Checker;
import org.apache.drill.exec.planner.sql.DrillCalciteSqlAggFunctionWrapper;
import org.apache.drill.exec.planner.sql.DrillCalciteSqlWrapper;
import org.apache.drill.exec.planner.sql.DrillSqlOperator;
import org.apache.drill.exec.planner.sql.TypeInferenceUtils;
import org.apache.drill.exec.planner.sql.parser.DrillCalciteWrapperUtility;
import org.apache.drill.shaded.guava.com.google.common.collect.ImmutableMap;

/**
 * Convertlet table Drill uses when translating validated {@link SqlNode} calls into {@link RexNode}
 * expressions.
 *
 * <p>A fixed set of Calcite operators (EXTRACT, SQRT, SUBSTRING, COALESCE, TIMESTAMPDIFF, ROW, RAND
 * and the AVG / STDDEV / VARIANCE family) are mapped to Drill-specific convertlets. Every other call
 * falls back to Calcite's {@link StandardConvertletTable}. Calls whose operator is a
 * {@link DrillCalciteSqlWrapper} are resolved against the Calcite operator they wrap.
 */
public class DrillConvertletTable
implements SqlRexConvertletTable {
    /*
     * Declared before INSTANCE so that it is initialized before the singleton is constructed.
     * The constructor only captures method references that read this field later, but keeping the
     * order explicit avoids a latent static-initialization-order trap.
     *
     * Single-operand operator named "CastHigh". Its planning-time return type is ANY, with the
     * nullability of its operand. It wraps intermediate values in the AVG / variance expansions.
     * NOTE(review): the actual widening it performs is implemented by Drill's function registry,
     * not in this file.
     */
    private static final DrillSqlOperator CAST_HIGH_OP = new DrillSqlOperator("CastHigh", new ArrayList<DrillFuncHolder>(), Checker.getChecker(1, 1), false, opBinding -> TypeInferenceUtils.createCalciteTypeWithNullability(opBinding.getTypeFactory(), SqlTypeName.ANY, opBinding.getOperandType(0).isNullable()), false);

    /** Shared singleton; the constructor is private. */
    public static final SqlRexConvertletTable INSTANCE = new DrillConvertletTable();

    /**
     * Drill-specific convertlets keyed by the (unwrapped) Calcite operator.
     * STDDEV and VARIANCE use the sample (unbiased) formulas, like STDDEV_SAMP / VAR_SAMP.
     */
    private final Map<SqlOperator, SqlRexConvertlet> operatorToConvertletMap =
        ImmutableMap.<SqlOperator, SqlRexConvertlet>builder()
            .put(SqlStdOperatorTable.EXTRACT, extractConvertlet())
            .put(SqlStdOperatorTable.SQRT, sqrtConvertlet())
            .put(SqlStdOperatorTable.SUBSTRING, substringConvertlet())
            .put(SqlStdOperatorTable.COALESCE, coalesceConvertlet())
            .put(SqlStdOperatorTable.TIMESTAMP_DIFF, timestampDiffConvertlet())
            .put(SqlStdOperatorTable.ROW, rowConvertlet())
            .put(SqlStdOperatorTable.RAND, randConvertlet())
            .put(SqlStdOperatorTable.AVG, avgVarianceConvertlet(DrillConvertletTable::expandAvg))
            .put(SqlStdOperatorTable.STDDEV_POP, avgVarianceConvertlet(arg -> expandVariance(arg, true, true)))
            .put(SqlStdOperatorTable.STDDEV_SAMP, avgVarianceConvertlet(arg -> expandVariance(arg, false, true)))
            .put(SqlStdOperatorTable.STDDEV, avgVarianceConvertlet(arg -> expandVariance(arg, false, true)))
            .put(SqlStdOperatorTable.VAR_POP, avgVarianceConvertlet(arg -> expandVariance(arg, true, false)))
            .put(SqlStdOperatorTable.VAR_SAMP, avgVarianceConvertlet(arg -> expandVariance(arg, false, false)))
            .put(SqlStdOperatorTable.VARIANCE, avgVarianceConvertlet(arg -> expandVariance(arg, false, false)))
            .build();

    private DrillConvertletTable() {
    }

    /**
     * Returns the convertlet for {@code call}.
     *
     * <p>For a {@link DrillCalciteSqlWrapper} operator, the wrapped Calcite operator is looked up in
     * the Drill map first. Otherwise the call is handed to {@link StandardConvertletTable} with the
     * wrapped operator temporarily installed, so the lookup sees the Calcite operator. The wrapper
     * is always restored afterwards.
     *
     * @param call the call to convert; when its operator is a wrapper and no Drill convertlet
     *     applies, it must be a {@link SqlBasicCall}
     * @return the convertlet to use for {@code call}
     */
    @Override
    public SqlRexConvertlet get(SqlCall call) {
        SqlOperator operator = call.getOperator();
        if (operator instanceof DrillCalciteSqlWrapper) {
            SqlOperator wrapped = DrillCalciteWrapperUtility.extractSqlOperatorFromWrapper(operator);
            SqlRexConvertlet convertlet = this.operatorToConvertletMap.get(wrapped);
            if (convertlet != null) {
                return convertlet;
            }
            // Mutates a shared parse-tree node: restore the wrapper even if the lookup throws.
            SqlBasicCall basicCall = (SqlBasicCall) call;
            basicCall.setOperator(wrapped);
            try {
                return StandardConvertletTable.INSTANCE.get(call);
            } finally {
                basicCall.setOperator(operator);
            }
        }
        SqlRexConvertlet convertlet = this.operatorToConvertletMap.get(operator);
        return convertlet != null ? convertlet : StandardConvertletTable.INSTANCE.get(call);
    }

    /**
     * Convertlet for EXTRACT.
     *
     * <p>When the operator is exactly {@code SqlStdOperatorTable.EXTRACT}, the return type is
     * BIGINT. Otherwise (e.g. a wrapped EXTRACT), the type is derived from the time unit in operand
     * 0. In both cases nullability follows operand 1.
     */
    private static SqlRexConvertlet extractConvertlet() {
        return (cx, call) -> {
            List<SqlNode> operands = call.getOperandList();
            List<RexNode> exprs = new ArrayList<>(operands.size());
            for (SqlNode node : operands) {
                exprs.add(cx.convertExpression(node));
            }
            RelDataTypeFactory typeFactory = cx.getTypeFactory();
            RelDataType returnType;
            if (call.getOperator() == SqlStdOperatorTable.EXTRACT) {
                returnType = typeFactory.createSqlType(SqlTypeName.BIGINT);
            } else {
                String timeUnit = ((SqlIntervalQualifier) operands.get(0)).timeUnitRange.toString();
                returnType = typeFactory.createSqlType(TypeInferenceUtils.getSqlTypeNameForTimeUnit(timeUnit));
            }
            // The value being extracted from (operand 1) decides nullability.
            returnType = typeFactory.createTypeWithNullability(returnType, exprs.get(1).getType().isNullable());
            return cx.getRexBuilder().makeCall(returnType, call.getOperator(), exprs);
        };
    }

    /**
     * Convertlet for SQRT. It rebuilds the call with the standard SQRT operator, letting
     * RexBuilder infer the type.
     */
    private static SqlRexConvertlet sqrtConvertlet() {
        return (cx, call) -> {
            RexNode operand = cx.convertExpression(call.operand(0));
            return cx.getRexBuilder().makeCall(SqlStdOperatorTable.SQRT, operand);
        };
    }

    /**
     * Convertlet for RAND. It emits a RAND operator that reports itself as non-deterministic, so
     * the planner does not treat separate invocations as interchangeable.
     */
    private static SqlRexConvertlet randConvertlet() {
        return (cx, call) -> {
            List<RexNode> operands = call.getOperandList().stream()
                .map(cx::convertExpression)
                .collect(Collectors.toList());
            return cx.getRexBuilder().makeCall(new SqlRandFunction() {
                @Override
                public boolean isDeterministic() {
                    return false;
                }
            }, operands);
        };
    }

    /**
     * Convertlet for SUBSTRING. It always returns VARCHAR, with nullability taken from the source
     * string (operand 0).
     */
    private static SqlRexConvertlet substringConvertlet() {
        return (cx, call) -> {
            List<RexNode> exprs = call.getOperandList().stream()
                .map(cx::convertExpression)
                .collect(Collectors.toList());
            RelDataType returnType = TypeInferenceUtils.createCalciteTypeWithNullability(cx.getTypeFactory(), SqlTypeName.VARCHAR, exprs.get(0).getType().isNullable());
            return cx.getRexBuilder().makeCall(returnType, SqlStdOperatorTable.SUBSTRING, exprs);
        };
    }

    /**
     * Convertlet for COALESCE.
     *
     * <p>It rewrites {@code COALESCE(a, b, ..., z)} as
     * {@code CASE WHEN a IS NOT NULL THEN a WHEN b IS NOT NULL THEN b ... ELSE z END}. A
     * single-argument COALESCE is just its argument.
     */
    private static SqlRexConvertlet coalesceConvertlet() {
        return (cx, call) -> {
            int operandsCount = call.operandCount();
            if (operandsCount == 1) {
                return cx.convertExpression(call.operand(0));
            }
            List<RexNode> caseOperands = new ArrayList<>(2 * operandsCount - 1);
            for (int i = 0; i < operandsCount - 1; ++i) {
                // Convert once and reuse the node for both the WHEN test and the THEN value.
                RexNode caseOperand = cx.convertExpression(call.operand(i));
                caseOperands.add(cx.getRexBuilder().makeCall(SqlStdOperatorTable.IS_NOT_NULL, caseOperand));
                caseOperands.add(caseOperand);
            }
            caseOperands.add(cx.convertExpression(call.operand(operandsCount - 1)));
            return cx.getRexBuilder().makeCall(SqlStdOperatorTable.CASE, caseOperands);
        };
    }

    /**
     * Convertlet for TIMESTAMPDIFF(unit, t1, t2).
     *
     * <p>The unit is re-expressed as a position-less single-unit interval qualifier. The result is
     * BIGINT, nullable if either timestamp operand is nullable.
     */
    private static SqlRexConvertlet timestampDiffConvertlet() {
        return (cx, call) -> {
            SqlIntervalQualifier unitLiteral = (SqlIntervalQualifier) call.operand(0);
            SqlIntervalQualifier qualifier = new SqlIntervalQualifier(unitLiteral.getUnit(), null, SqlParserPos.ZERO);
            List<RexNode> operands = Arrays.asList(cx.convertExpression(qualifier), cx.convertExpression(call.operand(1)), cx.convertExpression(call.operand(2)));
            RelDataTypeFactory typeFactory = cx.getTypeFactory();
            boolean nullable = cx.getValidator().getValidatedNodeType(call.operand(1)).isNullable()
                || cx.getValidator().getValidatedNodeType(call.operand(2)).isNullable();
            RelDataType returnType = typeFactory.createTypeWithNullability(typeFactory.createSqlType(SqlTypeName.BIGINT), nullable);
            return cx.getRexBuilder().makeCall(returnType, SqlStdOperatorTable.TIMESTAMP_DIFF, operands);
        };
    }

    /**
     * Convertlet for ROW. It converts each operand and rebuilds the call with the standard ROW
     * operator.
     */
    private static SqlRexConvertlet rowConvertlet() {
        return (cx, call) -> {
            List<RexNode> args = call.getOperandList().stream()
                .map(cx::convertExpression)
                .collect(Collectors.toList());
            return cx.getRexBuilder().makeCall(SqlStdOperatorTable.ROW, args);
        };
    }

    /**
     * Builds a convertlet that rewrites a single-argument aggregate into an equivalent SQL
     * expression using {@code expandFunc}, then converts that expression.
     *
     * @param expandFunc maps the aggregate's argument to the replacement expression
     */
    private static SqlRexConvertlet avgVarianceConvertlet(Function<SqlNode, SqlNode> expandFunc) {
        return (cx, call) -> cx.convertExpression(expandFunc.apply(call.operand(0)));
    }

    /**
     * Expands {@code AVG(x)} to {@code CastHigh(SUM(x)) / COUNT(x)}, where SUM is Drill's
     * wrapped SUM.
     */
    private static SqlNode expandAvg(SqlNode arg) {
        SqlParserPos pos = SqlParserPos.ZERO;
        SqlNode sum = DrillCalciteSqlAggFunctionWrapper.SUM.createCall(pos, arg);
        SqlNode count = SqlStdOperatorTable.COUNT.createCall(pos, arg);
        SqlNode sumAsDouble = CAST_HIGH_OP.createCall(pos, sum);
        return SqlStdOperatorTable.DIVIDE.createCall(pos, sumAsDouble, count);
    }

    /**
     * Expands variance / standard-deviation aggregates. With {@code x = CastHigh(arg)} it builds:
     *
     * <pre>
     *   v = CastHigh(SUM(x * x) - SUM(x) * SUM(x) / COUNT(x)) / d
     *   d = COUNT(x)        when biased  (population)
     *   d = COUNT(x) - 1    otherwise    (sample)
     *   result = POWER(v, 0.5) when sqrt, else v
     * </pre>
     *
     * @param arg the aggregate's argument
     * @param biased {@code true} for population statistics, {@code false} for sample statistics
     * @param sqrt {@code true} to return standard deviation instead of variance
     */
    private static SqlNode expandVariance(SqlNode arg, boolean biased, boolean sqrt) {
        SqlParserPos pos = SqlParserPos.ZERO;
        SqlNode castHighArg = CAST_HIGH_OP.createCall(pos, arg);
        SqlNode argSquared = SqlStdOperatorTable.MULTIPLY.createCall(pos, castHighArg, castHighArg);
        SqlNode sumArgSquared = DrillCalciteSqlAggFunctionWrapper.SUM.createCall(pos, argSquared);
        SqlNode sum = DrillCalciteSqlAggFunctionWrapper.SUM.createCall(pos, castHighArg);
        SqlNode sumSquared = SqlStdOperatorTable.MULTIPLY.createCall(pos, sum, sum);
        SqlNode count = SqlStdOperatorTable.COUNT.createCall(pos, castHighArg);
        SqlNode avgSumSquared = SqlStdOperatorTable.DIVIDE.createCall(pos, sumSquared, count);
        SqlNode diff = SqlStdOperatorTable.MINUS.createCall(pos, sumArgSquared, avgSumSquared);
        SqlNode denominator;
        if (biased) {
            denominator = count;
        } else {
            SqlNumericLiteral one = SqlLiteral.createExactNumeric("1", pos);
            denominator = SqlStdOperatorTable.MINUS.createCall(pos, count, one);
        }
        SqlNode diffAsDouble = CAST_HIGH_OP.createCall(pos, diff);
        SqlNode variance = SqlStdOperatorTable.DIVIDE.createCall(pos, diffAsDouble, denominator);
        if (!sqrt) {
            return variance;
        }
        SqlNumericLiteral half = SqlLiteral.createExactNumeric("0.5", pos);
        return SqlStdOperatorTable.POWER.createCall(pos, variance, half);
    }
}

