/*
 * Decompiled with CFR 0.152.
 */
package org.apache.drill.exec.physical.impl.aggregate;

import com.sun.codemodel.JAssignmentTarget;
import com.sun.codemodel.JExpr;
import com.sun.codemodel.JExpression;
import com.sun.codemodel.JVar;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import org.apache.commons.collections.CollectionUtils;
import org.apache.drill.common.exceptions.UserException;
import org.apache.drill.common.expression.ErrorCollectorImpl;
import org.apache.drill.common.expression.FunctionCall;
import org.apache.drill.common.expression.FunctionHolderExpression;
import org.apache.drill.common.expression.IfExpression;
import org.apache.drill.common.expression.LogicalExpression;
import org.apache.drill.common.expression.SchemaPath;
import org.apache.drill.common.logical.data.NamedExpression;
import org.apache.drill.common.map.CaseInsensitiveMap;
import org.apache.drill.common.types.TypeProtos;
import org.apache.drill.common.types.Types;
import org.apache.drill.exec.ExecConstants;
import org.apache.drill.exec.compile.sig.GeneratorMapping;
import org.apache.drill.exec.compile.sig.MappingSet;
import org.apache.drill.exec.exception.SchemaChangeException;
import org.apache.drill.exec.expr.ClassGenerator;
import org.apache.drill.exec.expr.CodeGenerator;
import org.apache.drill.exec.expr.DrillFuncHolderExpr;
import org.apache.drill.exec.expr.ExpressionTreeMaterializer;
import org.apache.drill.exec.expr.TypeHelper;
import org.apache.drill.exec.expr.ValueVectorWriteExpression;
import org.apache.drill.exec.ops.FragmentContext;
import org.apache.drill.exec.physical.config.HashAggregate;
import org.apache.drill.exec.physical.impl.aggregate.HashAggTemplate;
import org.apache.drill.exec.physical.impl.aggregate.HashAggregator;
import org.apache.drill.exec.physical.impl.common.Comparator;
import org.apache.drill.exec.physical.impl.common.HashTableConfig;
import org.apache.drill.exec.planner.physical.AggPrelBase;
import org.apache.drill.exec.record.AbstractRecordBatch;
import org.apache.drill.exec.record.BatchSchema;
import org.apache.drill.exec.record.MaterializedField;
import org.apache.drill.exec.record.RecordBatch;
import org.apache.drill.exec.record.RecordBatchMemoryManager;
import org.apache.drill.exec.record.RecordBatchSizer;
import org.apache.drill.exec.record.TypedFieldId;
import org.apache.drill.exec.record.VectorContainer;
import org.apache.drill.exec.record.VectorWrapper;
import org.apache.drill.exec.record.selection.SelectionVector2;
import org.apache.drill.exec.record.selection.SelectionVector4;
import org.apache.drill.exec.util.record.RecordBatchStats;
import org.apache.drill.exec.vector.FixedWidthVector;
import org.apache.drill.exec.vector.UntypedNullHolder;
import org.apache.drill.exec.vector.UntypedNullVector;
import org.apache.drill.exec.vector.ValueVector;
import org.apache.drill.exec.vector.complex.writer.BaseWriter;
import org.apache.drill.shaded.guava.com.google.common.collect.Lists;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Record batch that performs hash aggregation over its {@code incoming} batch.
 *
 * <p>The grouping and aggregation work is delegated to a generated {@link HashAggregator}
 * (built in {@link #createAggregatorInternal()}). This class does the rest:
 * <ul>
 *   <li>materializes the group-by and aggregate expressions;</li>
 *   <li>creates one outgoing vector per expression in {@code container};</li>
 *   <li>translates the aggregator's outcomes into {@link RecordBatch.IterOutcome} values.</li>
 * </ul>
 */
public class HashAggBatch
extends AbstractRecordBatch<HashAggregate> {
    static final Logger logger = LoggerFactory.getLogger(HashAggBatch.class);
    private HashAggregator aggregator;
    /**
     * Current input. The aggregator may swap it in {@link #innerNext()} when
     * {@code outputCurrentBatch()} returns an outcome other than AGG_NONE/AGG_OK/AGG_EMIT.
     */
    protected RecordBatch incoming;
    private LogicalExpression[] aggrExprs;
    private TypedFieldId[] groupByOutFieldIds;
    private TypedFieldId[] aggrOutFieldIds;
    /** One comparator per group-by key; all are IS_NOT_DISTINCT_FROM. */
    private final List<Comparator> comparators;
    private BatchSchema incomingSchema;
    /** Set by {@link #cancelIncoming()}; checked in {@link #innerNext()}. */
    private boolean wasKilled;
    private List<BaseWriter.ComplexWriter> complexWriters;
    private int numGroupByExprs;
    private int numAggrExprs;
    private boolean firstBatch = true;
    /** Output column name -> name of the incoming column it is derived from (used for width estimates). */
    private final Map<String, String> columnMapping;
    private final HashAggMemoryManager hashAggMemoryManager;
    private final GeneratorMapping UPDATE_AGGR_INSIDE = GeneratorMapping.create("setupInterior", "updateAggrValuesInternal", "resetValues", "cleanup");
    private final GeneratorMapping UPDATE_AGGR_OUTSIDE = GeneratorMapping.create("setupInterior", "outputRecordValues", "resetValues", "cleanup");
    private final MappingSet updateAggrValuesMapping = new MappingSet("incomingRowIdx", "outRowIdx", "htRowIdx", "incoming", "outgoing", "aggrValuesContainer", this.UPDATE_AGGR_INSIDE, this.UPDATE_AGGR_OUTSIDE, this.UPDATE_AGGR_INSIDE);

    /** Returns the output row count computed by the memory manager. */
    public int getOutputRowCount() {
        return this.hashAggMemoryManager.getOutputRowCount();
    }

    /** Returns the memory manager that sizes this operator's outgoing batches. */
    public RecordBatchMemoryManager getRecordBatchMemoryManager() {
        return this.hashAggMemoryManager;
    }

    /**
     * Creates the batch and derives the output batch size from the operator's memory limit.
     *
     * <p>The configured output batch size is reduced so that at least {@code minBatchesNeeded}
     * batches fit in the allocator limit. {@code minBatchesNeeded} is computed as follows:
     * <ul>
     *   <li>twice the configured per-partition minimum;</li>
     *   <li>doubled again for a second-phase aggregation when fallback is disabled.</li>
     * </ul>
     *
     * @param popConfig physical operator configuration
     * @param incoming  upstream batch to aggregate
     * @param context   fragment context (options, allocator, function registry)
     */
    public HashAggBatch(HashAggregate popConfig, RecordBatch incoming, FragmentContext context) {
        super(popConfig, context);
        this.incoming = incoming;
        this.wasKilled = false;

        int numGrpByExprs = popConfig.getGroupByExprs().size();
        this.comparators = Lists.newArrayListWithExpectedSize(numGrpByExprs);
        for (int i = 0; i < numGrpByExprs; i++) {
            this.comparators.add(Comparator.IS_NOT_DISTINCT_FROM);
        }

        boolean allowed = this.oContext.getAllocator().setLenient();
        logger.debug("Config: Is allocator lenient? {}", allowed);

        int configuredBatchSize = (int) context.getOptions().getOption(ExecConstants.OUTPUT_BATCH_SIZE_VALIDATOR);
        long memAvail = this.oContext.getAllocator().getLimit();
        long minBatchesPerPartition = context.getOptions().getOption(ExecConstants.HASHAGG_MIN_BATCHES_PER_PARTITION_VALIDATOR);
        // Clamp so the division below can never be by zero, even if the option were set to 0.
        long minBatchesNeeded = 2L * Math.max(1L, minBatchesPerPartition);
        boolean fallbackEnabled = context.getOptions().getOption("drill.exec.hashagg.fallback.enabled").bool_val;
        AggPrelBase.OperatorPhase phase = popConfig.getAggPhase();
        if (phase.is2nd() && !fallbackEnabled) {
            minBatchesNeeded *= 2L;
        }
        long maxBatchSize = memAvail / minBatchesNeeded;
        if (configuredBatchSize > maxBatchSize) {
            // Safe narrowing: maxBatchSize < configuredBatchSize, which is an int.
            int reducedBatchSize = (int) maxBatchSize;
            logger.trace("Reducing configured batch size from: {} to: {}, due to Mem limit: {}", configuredBatchSize, reducedBatchSize, memAvail);
            configuredBatchSize = reducedBatchSize;
        }
        this.hashAggMemoryManager = new HashAggMemoryManager(configuredBatchSize);
        RecordBatchStats.printConfiguredBatchSize(this.getRecordBatchStatsContext(), configuredBatchSize);
        this.columnMapping = CaseInsensitiveMap.newHashMap();
    }

    @Override
    public VectorContainer getOutgoingContainer() {
        return this.container;
    }

    /**
     * Returns the aggregator's current output count.
     *
     * <p>Returns 0 when the batch is DONE or when no aggregator has been created yet.
     */
    @Override
    public int getRecordCount() {
        if (this.state == AbstractRecordBatch.BatchState.DONE || this.aggregator == null) {
            return 0;
        }
        return this.aggregator.getOutputCount();
    }

    /**
     * Pulls the first incoming batch and builds the aggregator and the output schema from it.
     *
     * <p>If upstream immediately returns NONE, an empty schema is published and the batch is
     * marked DONE.
     */
    @Override
    public void buildSchema() {
        RecordBatch.IterOutcome outcome = this.next(this.incoming);
        if (outcome == RecordBatch.IterOutcome.NONE) {
            this.state = AbstractRecordBatch.BatchState.DONE;
            this.container.buildSchema(BatchSchema.SelectionVectorMode.NONE);
            return;
        }
        this.incomingSchema = this.incoming.getSchema();
        this.createAggregator();
        this.container.allocatePrecomputedChildCount(0, 0, 0);
        // Only feed sizing statistics when the first batch actually carries rows.
        if (this.incoming.getRecordCount() > 0) {
            this.hashAggMemoryManager.update();
        }
    }

    /**
     * Produces the next outgoing batch.
     *
     * <p>Steps:
     * <ol>
     *   <li>If the aggregator is already emitting output (build complete, early output or EMIT
     *       handling), return its next batch.</li>
     *   <li>Otherwise drive {@code doWork()} until it stops asking to be called again, then map
     *       its outcome to an {@link RecordBatch.IterOutcome}.</li>
     * </ol>
     *
     * @throws UserException if the incoming schema changes mid-stream
     */
    @Override
    public RecordBatch.IterOutcome innerNext() {
        if (this.aggregator.allFlushed()) {
            return RecordBatch.IterOutcome.NONE;
        }
        if (this.aggregator.buildComplete() || this.aggregator.earlyOutput() || this.aggregator.handlingEmit()) {
            HashAggregator.AggIterOutcome aggOut = this.aggregator.outputCurrentBatch();
            switch (aggOut) {
                case AGG_NONE:
                    return RecordBatch.IterOutcome.NONE;
                case AGG_OK:
                    return RecordBatch.IterOutcome.OK;
                case AGG_EMIT:
                    return RecordBatch.IterOutcome.EMIT;
                default:
                    // Any other outcome: pick up the aggregator's new input and continue with doWork().
                    this.incoming = this.aggregator.getNewIncoming();
                    break;
            }
        }
        if (this.wasKilled) {
            this.aggregator.cleanup();
            return RecordBatch.IterOutcome.NONE;
        }

        HashAggregator.AggOutcome out;
        do {
            out = this.aggregator.doWork();
        } while (out == HashAggregator.AggOutcome.CALL_WORK_AGAIN);

        switch (out) {
            case CLEANUP_AND_RETURN:
                this.container.zeroVectors();
                this.aggregator.cleanup();
                this.state = AbstractRecordBatch.BatchState.DONE;
                // fall through: the aggregator's outcome is still reported to the caller
            case RETURN_OUTCOME:
                return this.adjustFirstOutcome(this.aggregator.getOutcome());
            case UPDATE_AGGREGATOR:
                String reason = SchemaChangeException.schemaChanged("Hash aggregate does not support schema change", this.incomingSchema, this.incoming.getSchema()).getMessage();
                // Pass the message as an argument, not as the format string, so '%' in column names is safe.
                throw UserException.unsupportedError().message("%s", reason).build(logger);
            default:
                throw new IllegalStateException(String.format("Unknown state %s.", out));
        }
    }

    /**
     * Adjusts the first OK/OK_NEW_SCHEMA outcome.
     *
     * <p>When complex writers are present, the output schema is rebuilt and OK_NEW_SCHEMA is
     * returned (presumably because the writers change the outgoing vectors during the first
     * batch — TODO confirm). Every other outcome is returned unchanged.
     */
    private RecordBatch.IterOutcome adjustFirstOutcome(RecordBatch.IterOutcome outcome) {
        boolean isOk = outcome == RecordBatch.IterOutcome.OK || outcome == RecordBatch.IterOutcome.OK_NEW_SCHEMA;
        if (isOk && this.firstBatch) {
            if (CollectionUtils.isNotEmpty(this.complexWriters)) {
                this.container.buildSchema(BatchSchema.SelectionVectorMode.NONE);
                outcome = RecordBatch.IterOutcome.OK_NEW_SCHEMA;
            }
            this.firstBatch = false;
        }
        return outcome;
    }

    /** Creates the aggregator, recording the time spent as operator setup time. */
    private void createAggregator() {
        try {
            this.stats.startSetup();
            this.aggregator = this.createAggregatorInternal();
        } finally {
            this.stats.stopSetup();
        }
    }

    /**
     * Registers a complex writer used by a complex-writer aggregate function.
     *
     * <p>The list is created lazily, so this no longer fails if called before any complex-writer
     * expression has been seen.
     */
    public void addComplexWriter(BaseWriter.ComplexWriter writer) {
        if (this.complexWriters == null) {
            this.complexWriters = new ArrayList<>();
        }
        this.complexWriters.add(writer);
    }

    /**
     * Materializes the expressions, builds the outgoing container and sets up the aggregator.
     *
     * <ul>
     *   <li>Each group-by and aggregate expression is materialized against {@code incoming}.</li>
     *   <li>One outgoing vector is added to {@code container} per expression.</li>
     *   <li>The {@link HashAggregator} class is generated and {@code setup} is called on it.</li>
     * </ul>
     *
     * @return the configured aggregator
     * @throws UserException if an expression fails to materialize or an aggregate input is a union type
     */
    protected HashAggregator createAggregatorInternal() {
        CodeGenerator<HashAggregator> top = CodeGenerator.get(HashAggregator.TEMPLATE_DEFINITION, this.context.getOptions());
        ClassGenerator<HashAggregator> cg = top.getRoot();
        ClassGenerator<HashAggregator> cgInner = cg.getInnerGenerator("BatchHolder");
        top.plainJavaCapable(true);
        this.container.clear();

        List<NamedExpression> keyExprs = this.getKeyExpressions();
        List<NamedExpression> valueExprs = this.getValueExpressions();
        this.numGroupByExprs = keyExprs != null ? keyExprs.size() : 0;
        this.numAggrExprs = valueExprs != null ? valueExprs.size() : 0;
        this.aggrExprs = new LogicalExpression[this.numAggrExprs];
        this.groupByOutFieldIds = new TypedFieldId[this.numGroupByExprs];
        this.aggrOutFieldIds = new TypedFieldId[this.numAggrExprs];
        ErrorCollectorImpl collector = new ErrorCollectorImpl();

        for (int i = 0; i < this.numGroupByExprs; i++) {
            NamedExpression ne = keyExprs.get(i);
            LogicalExpression expr = ExpressionTreeMaterializer.materialize(ne.getExpr(), this.incoming, collector, this.context.getFunctionRegistry());
            if (expr == null) {
                continue;
            }
            MaterializedField outputField = MaterializedField.create(ne.getRef().getAsNamePart().getName(), expr.getMajorType());
            ValueVector vv = TypeHelper.getNewVector(outputField, this.oContext.getAllocator());
            this.groupByOutFieldIds[i] = this.container.add(vv);
            // Strip backticks so the mapping holds the bare incoming column text.
            this.columnMapping.put(outputField.getName(), ne.getExpr().toString().replace('`', ' ').trim());
        }
        // Report group-by errors even when there are no aggregate expressions; otherwise a failed
        // key would leave a null field id and setup would continue.
        collector.reportErrors(logger);

        // Counted for SUM/MAX/MIN; passed to setup below as count * 8.
        int extraNonNullColumns = 0;
        for (int i = 0; i < this.numAggrExprs; i++) {
            NamedExpression ne = valueExprs.get(i);
            LogicalExpression expr = ExpressionTreeMaterializer.materialize(ne.getExpr(), this.incoming, collector, this.context.getFunctionRegistry());
            if (expr instanceof IfExpression) {
                throw UserException.unsupportedError(new UnsupportedOperationException("Union type not supported in aggregate functions")).build(logger);
            }
            collector.reportErrors(logger);
            if (expr == null) {
                continue;
            }
            String outputName = ne.getRef().getAsNamePart().getName();
            if (expr instanceof DrillFuncHolderExpr && ((DrillFuncHolderExpr) expr).getHolder().isComplexWriterFuncHolder()) {
                // Complex-writer functions write through a ComplexWriter; add a placeholder
                // untyped-null vector under the output name.
                if (this.complexWriters == null) {
                    this.complexWriters = new ArrayList<>();
                } else {
                    this.complexWriters.clear();
                }
                ((DrillFuncHolderExpr) expr).setFieldReference(ne.getRef());
                MaterializedField field = MaterializedField.create(outputName, UntypedNullHolder.TYPE);
                this.container.add(new UntypedNullVector(field, this.container.getAllocator()));
                this.aggrExprs[i] = expr;
                continue;
            }
            MaterializedField outputField = MaterializedField.create(outputName, expr.getMajorType());
            ValueVector vv = TypeHelper.getNewVector(outputField, this.oContext.getAllocator());
            this.aggrOutFieldIds[i] = this.container.add(vv);
            this.aggrExprs[i] = new ValueVectorWriteExpression(this.aggrOutFieldIds[i], expr, true);
            if (expr instanceof FunctionHolderExpression) {
                String funcName = ((FunctionHolderExpression) expr).getName();
                if (funcName.equals("sum") || funcName.equals("max") || funcName.equals("min")) {
                    extraNonNullColumns++;
                }
                this.mapAggregateInputColumn(outputField.getName(), ne.getExpr());
            } else {
                this.columnMapping.put(outputField.getName(), outputName);
            }
        }

        this.setupUpdateAggrValues(cgInner);
        this.setupGetIndex(cg);
        cg.getBlock("resetValues")._return(JExpr.TRUE);
        this.container.buildSchema(BatchSchema.SelectionVectorMode.NONE);
        HashAggregator agg = this.context.getImplementationClass(top);
        HashTableConfig htConfig = new HashTableConfig((int) this.context.getOptions().getOption(ExecConstants.MIN_HASH_TABLE_SIZE), 0.75f, keyExprs, null, this.comparators);
        agg.setup((HashAggregate) this.popConfig, htConfig, this.context, this.oContext, this.incoming, this, this.aggrExprs, cgInner.getWorkspaceTypes(), cgInner, this.groupByOutFieldIds, this.container, extraNonNullColumns * 8);
        return agg;
    }

    /**
     * Records which incoming column an aggregate's output is derived from, for width estimates.
     *
     * <p>Handles {@code f(col)} and {@code f(g(col, ...))}. Anything else (a non-call expression,
     * no arguments, or no column reference in those positions) is skipped instead of failing on
     * an unchecked cast or an empty argument list.
     */
    private void mapAggregateInputColumn(String outputName, LogicalExpression originalExpr) {
        if (!(originalExpr instanceof FunctionCall)) {
            return;
        }
        List<LogicalExpression> args = ((FunctionCall) originalExpr).args();
        if (args.isEmpty()) {
            return;
        }
        LogicalExpression firstArg = args.get(0);
        if (firstArg instanceof SchemaPath) {
            this.columnMapping.put(outputName, ((SchemaPath) firstArg).getAsNamePart().getName());
        } else if (firstArg instanceof FunctionCall) {
            List<LogicalExpression> nestedArgs = ((FunctionCall) firstArg).args();
            if (!nestedArgs.isEmpty() && nestedArgs.get(0) instanceof SchemaPath) {
                this.columnMapping.put(outputName, ((SchemaPath) nestedArgs.get(0)).getAsNamePart().getName());
            }
        }
    }

    /** Returns the group-by expressions from the operator config. */
    protected List<NamedExpression> getKeyExpressions() {
        return ((HashAggregate) this.popConfig).getGroupByExprs();
    }

    /** Returns the aggregate expressions from the operator config. */
    protected List<NamedExpression> getValueExpressions() {
        return ((HashAggregate) this.popConfig).getAggrExprs();
    }

    /** Emits code for every materialized aggregate expression into the BatchHolder generator. */
    private void setupUpdateAggrValues(ClassGenerator<HashAggregator> cg) {
        cg.setMappingSet(this.updateAggrValuesMapping);
        for (LogicalExpression aggr : this.aggrExprs) {
            cg.addExpr(aggr, ClassGenerator.BlkCreateMode.TRUE);
        }
    }

    /**
     * Generates {@code getVectorIndex} for the incoming batch's selection-vector mode.
     *
     * <ul>
     *   <li>NONE: returns {@code recordIndex} as-is.</li>
     *   <li>TWO_BYTE / FOUR_BYTE: maps {@code recordIndex} through the incoming SV2/SV4, which is
     *       cached in a class field during {@code doSetup}.</li>
     * </ul>
     */
    private void setupGetIndex(ClassGenerator<HashAggregator> cg) {
        switch (this.incoming.getSchema().getSelectionVectorMode()) {
            case FOUR_BYTE: {
                JVar var = cg.declareClassField("sv4_", cg.getModel()._ref(SelectionVector4.class));
                cg.getBlock("doSetup").assign(var, JExpr.direct("incoming").invoke("getSelectionVector4"));
                cg.getBlock("getVectorIndex")._return(var.invoke("get").arg(JExpr.direct("recordIndex")));
                return;
            }
            case NONE: {
                cg.getBlock("getVectorIndex")._return(JExpr.direct("recordIndex"));
                return;
            }
            case TWO_BYTE: {
                JVar var = cg.declareClassField("sv2_", cg.getModel()._ref(SelectionVector2.class));
                cg.getBlock("doSetup").assign(var, JExpr.direct("incoming").invoke("getSelectionVector2"));
                cg.getBlock("getVectorIndex")._return(var.invoke("getIndex").arg(JExpr.direct("recordIndex")));
                return;
            }
            default:
                // No code is generated for other modes.
                return;
        }
    }

    /** Copies the memory manager's batch statistics into operator metrics and the batch-stats log. */
    private void updateStats() {
        this.stats.setLongStat(HashAggTemplate.Metric.INPUT_BATCH_COUNT, this.hashAggMemoryManager.getNumIncomingBatches());
        this.stats.setLongStat(HashAggTemplate.Metric.AVG_INPUT_BATCH_BYTES, this.hashAggMemoryManager.getAvgInputBatchSize());
        this.stats.setLongStat(HashAggTemplate.Metric.AVG_INPUT_ROW_BYTES, this.hashAggMemoryManager.getAvgInputRowWidth());
        this.stats.setLongStat(HashAggTemplate.Metric.INPUT_RECORD_COUNT, this.hashAggMemoryManager.getTotalInputRecords());
        this.stats.setLongStat(HashAggTemplate.Metric.OUTPUT_BATCH_COUNT, this.hashAggMemoryManager.getNumOutgoingBatches());
        this.stats.setLongStat(HashAggTemplate.Metric.AVG_OUTPUT_BATCH_BYTES, this.hashAggMemoryManager.getAvgOutputBatchSize());
        this.stats.setLongStat(HashAggTemplate.Metric.AVG_OUTPUT_ROW_BYTES, this.hashAggMemoryManager.getAvgOutputRowWidth());
        this.stats.setLongStat(HashAggTemplate.Metric.OUTPUT_RECORD_COUNT, this.hashAggMemoryManager.getTotalOutputRecords());
        RecordBatchStats.logRecordBatchStats(this.getRecordBatchStatsContext(), "incoming aggregate: count : %d, avg bytes : %d,  avg row bytes : %d, record count : %d", this.hashAggMemoryManager.getNumIncomingBatches(), this.hashAggMemoryManager.getAvgInputBatchSize(), this.hashAggMemoryManager.getAvgInputRowWidth(), this.hashAggMemoryManager.getTotalInputRecords());
        RecordBatchStats.logRecordBatchStats(this.getRecordBatchStatsContext(), "outgoing aggregate: count : %d, avg bytes : %d,  avg row bytes : %d, record count : %d", this.hashAggMemoryManager.getNumOutgoingBatches(), this.hashAggMemoryManager.getAvgOutputBatchSize(), this.hashAggMemoryManager.getAvgOutputRowWidth(), this.hashAggMemoryManager.getTotalOutputRecords());
    }

    /** Cleans up the aggregator (if created), publishes statistics, then closes the base batch. */
    @Override
    public void close() {
        if (this.aggregator != null) {
            this.aggregator.cleanup();
        }
        this.updateStats();
        super.close();
    }

    /** Marks this batch as killed (checked by {@link #innerNext()}) and cancels upstream. */
    @Override
    protected void cancelIncoming() {
        this.wasKilled = true;
        this.incoming.cancel();
    }

    @Override
    public void dump() {
        logger.error("HashAggBatch[container={}, aggregator={}, groupByOutFieldIds={}, aggrOutFieldIds={}, incomingSchema={}, numGroupByExprs={}, numAggrExprs={}, popConfig={}]", this.container, this.aggregator, Arrays.toString(this.groupByOutFieldIds), Arrays.toString(this.aggrOutFieldIds), this.incomingSchema, this.numGroupByExprs, this.numAggrExprs, this.popConfig);
    }

    /**
     * Tracks incoming batch sizes and estimates the outgoing row width from the output
     * container's vectors.
     */
    private class HashAggMemoryManager
    extends RecordBatchMemoryManager {

        HashAggMemoryManager(int outputBatchSize) {
            super(outputBatchSize);
        }

        /** Updates statistics from the enclosing batch's current {@code incoming}. */
        @Override
        public void update() {
            this.update(HashAggBatch.this.incoming);
        }

        /**
         * Measures {@code incomingRecordBatch} and recomputes the estimated outgoing row width as
         * the sum of per-column width estimates over the output container.
         */
        @Override
        public void update(RecordBatch incomingRecordBatch) {
            this.setRecordBatchSizer(new RecordBatchSizer(incomingRecordBatch));
            int newOutgoingRowWidth = 0;
            for (VectorWrapper<?> w : HashAggBatch.this.container) {
                newOutgoingRowWidth += this.estimateColumnWidth(w);
            }
            // The "width changed" result is ignored; no resizing of already-sized batches happens here.
            this.updateIfNeeded(newOutgoingRowWidth);
            this.updateIncomingStats();
            RecordBatchStats.logRecordBatchStats(RecordBatchStats.RecordBatchIOType.INPUT, this.getRecordBatchSizer(), HashAggBatch.this.getRecordBatchStatsContext());
        }

        /**
         * Estimates one output column's per-row width.
         *
         * <ul>
         *   <li>Fixed-width vectors: their value width.</li>
         *   <li>Mapped columns: the mapped incoming column's allocation size per entry.</li>
         *   <li>Otherwise: the type's nominal size, or 0 for an unmapped complex type.</li>
         * </ul>
         */
        private int estimateColumnWidth(VectorWrapper<?> w) {
            ValueVector vector = w.getValueVector();
            if (vector instanceof FixedWidthVector) {
                return ((FixedWidthVector) vector).getValueWidth();
            }
            TypeProtos.MajorType type = w.getField().getType();
            String sourceColumn = HashAggBatch.this.columnMapping.get(vector.getField().getName());
            if (sourceColumn == null) {
                return Types.isComplex(type) ? 0 : TypeHelper.getSize(type);
            }
            RecordBatchSizer.ColumnSize columnSize = this.getRecordBatchSizer().getColumn(sourceColumn);
            return columnSize == null ? TypeHelper.getSize(type) : columnSize.getAllocSizePerEntry();
        }
    }
}

