/*
 * Decompiled with CFR 0.152.
 */
package org.apache.drill.exec.physical.impl.filter;

import java.util.ArrayList;
import java.util.BitSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import org.apache.drill.common.exceptions.UserException;
import org.apache.drill.common.expression.ExpressionPosition;
import org.apache.drill.common.expression.LogicalExpression;
import org.apache.drill.common.expression.PathSegment;
import org.apache.drill.common.expression.SchemaPath;
import org.apache.drill.exec.exception.OutOfMemoryException;
import org.apache.drill.exec.exception.SchemaChangeException;
import org.apache.drill.exec.expr.ValueVectorReadExpression;
import org.apache.drill.exec.expr.fn.impl.ValueVectorHashHelper;
import org.apache.drill.exec.ops.FragmentContext;
import org.apache.drill.exec.ops.MetricDef;
import org.apache.drill.exec.physical.config.RuntimeFilterPOP;
import org.apache.drill.exec.record.AbstractSingleRecordBatch;
import org.apache.drill.exec.record.BatchSchema;
import org.apache.drill.exec.record.RecordBatch;
import org.apache.drill.exec.record.TypedFieldId;
import org.apache.drill.exec.record.VectorWrapper;
import org.apache.drill.exec.record.selection.SelectionVector2;
import org.apache.drill.exec.record.selection.SelectionVector4;
import org.apache.drill.exec.work.filter.BloomFilter;
import org.apache.drill.exec.work.filter.RuntimeFilterWritable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class RuntimeFilterRecordBatch
extends AbstractSingleRecordBatch<RuntimeFilterPOP> {
    private static final Logger logger = LoggerFactory.getLogger(RuntimeFilterRecordBatch.class);
    private SelectionVector2 sv2;
    private ValueVectorHashHelper.Hash64 hash64;
    private final Map<String, Integer> field2id = new HashMap<String, Integer>();
    private List<String> toFilterFields;
    private List<BloomFilter> bloomFilters;
    private RuntimeFilterWritable current;
    private int originalRecordCount;
    private long filteredRows;
    private long appliedTimes;
    private int batchTimes;
    private boolean waited;
    private final boolean enableRFWaiting;
    private final long maxWaitingTime;
    private final long rfIdentifier;

    public RuntimeFilterRecordBatch(RuntimeFilterPOP pop, RecordBatch incoming, FragmentContext context) throws OutOfMemoryException {
        super(pop, context, incoming);
        this.enableRFWaiting = context.getOptions().getBoolean("exec.hashjoin.runtime_filter.waiting.enable");
        this.maxWaitingTime = context.getOptions().getLong("exec.hashjoin.runtime_filter.max.waiting.time");
        this.rfIdentifier = pop.getIdentifier();
    }

    @Override
    public FragmentContext getContext() {
        return this.context;
    }

    @Override
    public int getRecordCount() {
        return this.sv2.getCount();
    }

    @Override
    public SelectionVector2 getSelectionVector2() {
        return this.sv2;
    }

    @Override
    public SelectionVector4 getSelectionVector4() {
        return null;
    }

    @Override
    protected RecordBatch.IterOutcome doWork() {
        this.originalRecordCount = this.incoming.getRecordCount();
        this.sv2.setBatchActualRecordCount(this.originalRecordCount);
        this.applyRuntimeFilter();
        this.container.transferIn(this.incoming.getContainer());
        this.container.setRecordCount(this.originalRecordCount);
        this.updateStats();
        return this.getFinalOutcome(false);
    }

    @Override
    public void close() {
        if (this.sv2 != null) {
            this.sv2.clear();
        }
        super.close();
        if (this.current != null) {
            this.current.close();
        }
    }

    @Override
    protected boolean setupNewSchema() {
        if (this.sv2 != null) {
            this.sv2.clear();
        }
        this.container.clear();
        this.hash64 = null;
        switch (this.incoming.getSchema().getSelectionVectorMode()) {
            case NONE: {
                if (this.sv2 != null) break;
                this.sv2 = new SelectionVector2(this.oContext.getAllocator());
                break;
            }
            case TWO_BYTE: {
                this.sv2 = new SelectionVector2(this.oContext.getAllocator());
                break;
            }
            default: {
                throw new UnsupportedOperationException();
            }
        }
        for (VectorWrapper v : this.incoming) {
            this.container.addOrGet(v.getField(), this.callBack);
        }
        this.setupHashHelper();
        if (this.container.isSchemaChanged()) {
            this.container.buildSchema(BatchSchema.SelectionVectorMode.TWO_BYTE);
            return true;
        }
        return false;
    }

    private void setupHashHelper() {
        this.current = this.context.getRuntimeFilter(this.rfIdentifier);
        if (this.current == null) {
            return;
        }
        if (this.bloomFilters == null) {
            this.bloomFilters = this.current.unwrap();
        }
        if (this.hash64 == null) {
            ValueVectorHashHelper hashHelper = new ValueVectorHashHelper(this.incoming, this.context);
            try {
                this.toFilterFields = this.current.getRuntimeFilterBDef().getProbeFieldsList();
                ArrayList<ValueVectorReadExpression> hashFieldExps = new ArrayList<ValueVectorReadExpression>();
                ArrayList<TypedFieldId> typedFieldIds = new ArrayList<TypedFieldId>();
                for (String toFilterField : this.toFilterFields) {
                    SchemaPath schemaPath = new SchemaPath(new PathSegment.NameSegment(toFilterField), ExpressionPosition.UNKNOWN);
                    TypedFieldId typedFieldId = this.container.getValueVectorId(schemaPath);
                    int[] fieldIds = typedFieldId.getFieldIds();
                    this.field2id.put(toFilterField, fieldIds[0]);
                    typedFieldIds.add(typedFieldId);
                    ValueVectorReadExpression toHashFieldExp = new ValueVectorReadExpression(typedFieldId);
                    hashFieldExps.add(toHashFieldExp);
                }
                this.hash64 = hashHelper.getHash64(hashFieldExps.toArray(new LogicalExpression[hashFieldExps.size()]), typedFieldIds.toArray(new TypedFieldId[typedFieldIds.size()]));
            }
            catch (Exception e) {
                throw UserException.internalError(e).build(logger);
            }
        }
    }

    private void applyRuntimeFilter() {
        if (this.originalRecordCount <= 0) {
            this.sv2.setRecordCount(0);
            return;
        }
        this.current = this.context.getRuntimeFilter(this.rfIdentifier);
        this.timedWaiting();
        ++this.batchTimes;
        this.sv2.allocateNew(this.originalRecordCount);
        if (this.current == null) {
            for (int i = 0; i < this.originalRecordCount; ++i) {
                this.sv2.setIndex(i, i);
            }
            this.sv2.setRecordCount(this.originalRecordCount);
            return;
        }
        this.setupHashHelper();
        BitSet bitSet = new BitSet(this.originalRecordCount);
        int filterSize = this.toFilterFields.size();
        int svIndex = 0;
        if (filterSize == 1) {
            BloomFilter bloomFilter = this.bloomFilters.get(0);
            String fieldName = this.toFilterFields.get(0);
            int fieldId = this.field2id.get(fieldName);
            for (int rowIndex = 0; rowIndex < this.originalRecordCount; ++rowIndex) {
                long hash;
                try {
                    hash = this.hash64.hash64Code(rowIndex, 0, fieldId);
                }
                catch (SchemaChangeException e) {
                    throw new UnsupportedOperationException(e);
                }
                boolean contain = bloomFilter.find(hash);
                if (contain) {
                    this.sv2.setIndex(svIndex, rowIndex);
                    ++svIndex;
                    continue;
                }
                ++this.filteredRows;
            }
        } else {
            int i;
            for (i = 0; i < this.toFilterFields.size(); ++i) {
                BloomFilter bloomFilter = this.bloomFilters.get(i);
                String fieldName = this.toFilterFields.get(i);
                try {
                    this.computeBitSet(this.field2id.get(fieldName), bloomFilter, bitSet);
                    continue;
                }
                catch (SchemaChangeException e) {
                    throw new UnsupportedOperationException(e);
                }
            }
            for (i = 0; i < this.originalRecordCount; ++i) {
                boolean contain = bitSet.get(i);
                if (contain) {
                    this.sv2.setIndex(svIndex, i);
                    ++svIndex;
                    continue;
                }
                ++this.filteredRows;
            }
        }
        ++this.appliedTimes;
        this.sv2.setRecordCount(svIndex);
    }

    private void computeBitSet(int fieldId, BloomFilter bloomFilter, BitSet bitSet) throws SchemaChangeException {
        for (int rowIndex = 0; rowIndex < this.originalRecordCount; ++rowIndex) {
            long hash = this.hash64.hash64Code(rowIndex, 0, fieldId);
            boolean contain = bloomFilter.find(hash);
            if (contain) {
                bitSet.set(rowIndex, true);
                continue;
            }
            bitSet.set(rowIndex, false);
        }
    }

    @Override
    public void dump() {
        logger.error("RuntimeFilterRecordBatch[container={}, selectionVector={}, toFilterFields={}, originalRecordCount={}, batchSchema={}]", new Object[]{this.container, this.sv2, this.toFilterFields, this.originalRecordCount, this.incoming.getSchema()});
    }

    public void updateStats() {
        this.stats.setLongStat(Metric.FILTERED_ROWS, this.filteredRows);
        this.stats.setLongStat(Metric.APPLIED_TIMES, this.appliedTimes);
    }

    private void timedWaiting() {
        if (!this.enableRFWaiting || this.waited) {
            return;
        }
        if (this.current == null && this.batchTimes > 0) {
            this.waited = true;
            try {
                this.stats.startWait();
                this.current = this.context.getRuntimeFilter(this.rfIdentifier, this.maxWaitingTime, TimeUnit.MILLISECONDS);
            }
            finally {
                this.stats.stopWait();
            }
        }
    }

    public static enum Metric implements MetricDef
    {
        FILTERED_ROWS,
        APPLIED_TIMES;


        @Override
        public int metricId() {
            return this.ordinal();
        }
    }
}

