/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.operators.hash;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.flink.api.common.typeutils.GenericPairComparator;
import org.apache.flink.api.common.typeutils.TypeComparator;
import org.apache.flink.api.common.typeutils.TypePairComparator;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.common.typeutils.base.ByteValueSerializer;
import org.apache.flink.api.common.typeutils.base.LongComparator;
import org.apache.flink.api.common.typeutils.base.LongSerializer;
import org.apache.flink.api.common.typeutils.base.array.BytePrimitiveArrayComparator;
import org.apache.flink.api.common.typeutils.base.array.BytePrimitiveArraySerializer;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.typeutils.runtime.TupleComparator;
import org.apache.flink.api.java.typeutils.runtime.TupleSerializer;
import org.apache.flink.api.java.typeutils.runtime.ValueComparator;
import org.apache.flink.core.memory.MemorySegment;
import org.apache.flink.core.memory.MemorySegmentFactory;
import org.apache.flink.runtime.io.disk.iomanager.IOManager;
import org.apache.flink.runtime.io.disk.iomanager.IOManagerAsync;
import org.apache.flink.runtime.operators.hash.MutableHashTable;
import org.apache.flink.types.ByteValue;
import org.apache.flink.util.MutableObjectIterator;
import org.assertj.core.api.AbstractThrowableAssert;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;

/**
 * Tests for {@link MutableHashTable} under tight memory budgets. Each test builds a table over
 * a small, fixed number of {@link MemorySegment}s and an {@link IOManagerAsync} for spilling; the
 * test names describe the scenario each one targets.
 */
class HashTableTest {

    /** Serializer for build-side records: {@code (long key, byte[] payload)} tuples. */
    private final TypeSerializer<Tuple2<Long, byte[]>> buildSerializer;

    /** Serializer for probe-side records, which are bare {@code Long} keys. */
    private final TypeSerializer<Long> probeSerializer;

    /** Compares build-side tuples on field 0 (the long key) only. */
    private final TypeComparator<Tuple2<Long, byte[]>> buildComparator;

    /** Compares probe-side {@code Long} keys in ascending order. */
    private final TypeComparator<Long> probeComparator;

    /** Matches a probe-side {@code Long} against field 0 of a build-side tuple. */
    private final TypePairComparator<Long, Tuple2<Long, byte[]>> pairComparator;

    HashTableTest() {
        TypeSerializer<?>[] fieldSerializers = {
            LongSerializer.INSTANCE, BytePrimitiveArraySerializer.INSTANCE
        };
        // Tuple2.class is a raw Class token; the double cast yields the parameterized token the
        // generic TupleSerializer constructor expects.
        @SuppressWarnings("unchecked")
        Class<Tuple2<Long, byte[]>> clazz =
                (Class<Tuple2<Long, byte[]>>) (Class<?>) Tuple2.class;
        this.buildSerializer = new TupleSerializer<>(clazz, fieldSerializers);
        this.probeSerializer = LongSerializer.INSTANCE;

        TypeComparator<?>[] comparators = {new LongComparator(true)};
        TypeSerializer<?>[] comparatorSerializers = {LongSerializer.INSTANCE};
        this.buildComparator =
                new TupleComparator<>(new int[] {0}, comparators, comparatorSerializers);
        this.probeComparator = new LongComparator(true);

        this.pairComparator =
                new TypePairComparator<Long, Tuple2<Long, byte[]>>() {
                    private long ref;

                    @Override
                    public void setReference(Long reference) {
                        this.ref = reference;
                    }

                    @Override
                    public boolean equalToReference(Tuple2<Long, byte[]> candidate) {
                        // Unbox explicitly so the key is compared by value, not by reference.
                        return candidate.f0.longValue() == this.ref;
                    }

                    @Override
                    public int compareToReference(Tuple2<Long, byte[]> candidate) {
                        return Long.compare(this.ref, candidate.f0);
                    }
                };
    }

    /**
     * Every build record carries the same key (42, see {@link TupleBytesIterator}), so the
     * records cannot be separated by re-partitioning; probing is expected to fail with a
     * {@link RuntimeException} reporting that the maximum recursion depth was exceeded, and the
     * table must still leave no spill files behind after closing.
     */
    @Test
    void testBufferMissingForProbing() throws Exception {
        try (IOManagerAsync ioMan = new IOManagerAsync()) {
            final int pageSize = 32 * 1024;
            final int numSegments = 34;
            final int numRecords = 3400;
            final int recordLen = 270;
            // 270 - 8 - 4 = 258 payload bytes; presumably the 8-byte long key plus the byte[]
            // length prefix make up the rest of the record. TODO confirm against the serializer.
            final byte[] payload = new byte[recordLen - 8 - 4];

            List<MemorySegment> memory = getMemory(numSegments, pageSize);
            MutableHashTable<Tuple2<Long, byte[]>, Long> table =
                    new MutableHashTable<>(
                            buildSerializer,
                            probeSerializer,
                            buildComparator,
                            probeComparator,
                            pairComparator,
                            memory,
                            ioMan,
                            16,
                            false);
            try {
                table.open(
                        new TupleBytesIterator(payload, numRecords), new LongIterator(10000L));

                Assertions.assertThatThrownBy(
                                () -> {
                                    while (table.nextRecord()) {
                                        MutableObjectIterator<Tuple2<Long, byte[]>> matches =
                                                table.getBuildSideIterator();
                                        while (matches.next() != null) {
                                            // drain all matches of the current probe record
                                        }
                                    }
                                })
                        .isInstanceOf(RuntimeException.class)
                        .hasMessageContaining("exceeded maximum number of recursions");
            } finally {
                // Always release the table's memory and spill files, even if the assertion fails.
                table.close();
            }
            checkNoTempFilesRemain(ioMan);
        }
    }

    /**
     * Opens a table whose build side (100,000,000 identical {@link ByteValue} records) is far
     * larger than the 34 x 32 KiB of memory it is given, then closes it and verifies that no
     * spill files remain.
     */
    @Test
    void testSpillingFreesOnlyOverflowSegments() throws Exception {
        ByteValueSerializer serializer = ByteValueSerializer.INSTANCE;
        ValueComparator<ByteValue> buildComparator = new ValueComparator<>(true, ByteValue.class);
        ValueComparator<ByteValue> probeComparator = new ValueComparator<>(true, ByteValue.class);
        GenericPairComparator<ByteValue, ByteValue> pairComparator =
                new GenericPairComparator<>(buildComparator, probeComparator);

        try (IOManagerAsync ioMan = new IOManagerAsync()) {
            final int pageSize = 32 * 1024;
            final int numSegments = 34;
            List<MemorySegment> memory = getMemory(numSegments, pageSize);
            MutableHashTable<ByteValue, ByteValue> table =
                    new MutableHashTable<>(
                            serializer,
                            serializer,
                            buildComparator,
                            probeComparator,
                            pairComparator,
                            memory,
                            ioMan,
                            1,
                            false);
            try {
                table.open(new ByteValueIterator(100000000L), new ByteValueIterator(1L));
            } finally {
                table.close();
            }
            checkNoTempFilesRemain(ioMan);
        }
    }

    /**
     * Builds a table from {@code numElements} arrays filled with 0 followed by
     * {@code numElements} arrays filled with 1, probes with one array of each, and checks that
     * every probe record matches exactly {@code numElements} build records.
     */
    @Test
    void testSpillingWhenBuildingTableWithoutOverflow() throws Exception {
        try (IOManagerAsync ioMan = new IOManagerAsync()) {
            BytePrimitiveArraySerializer serializer = BytePrimitiveArraySerializer.INSTANCE;
            BytePrimitiveArrayComparator buildComparator = new BytePrimitiveArrayComparator(true);
            BytePrimitiveArrayComparator probeComparator = new BytePrimitiveArrayComparator(true);
            GenericPairComparator<byte[], byte[]> pairComparator =
                    new GenericPairComparator<>(
                            new BytePrimitiveArrayComparator(true),
                            new BytePrimitiveArrayComparator(true));

            final int pageSize = 128;
            final int numSegments = 33;
            final int arrayLength = 128;
            final int numElements = 9;
            List<MemorySegment> memory = getMemory(numSegments, pageSize);
            MutableHashTable<byte[], byte[]> table =
                    new MutableHashTable<>(
                            serializer,
                            serializer,
                            buildComparator,
                            probeComparator,
                            pairComparator,
                            memory,
                            ioMan,
                            1,
                            false);
            try {
                // The byte argument needs an explicit cast: int literals are not narrowed to
                // byte in method-invocation contexts.
                table.open(
                        new CombiningIterator<byte[]>(
                                new ByteArrayIterator(numElements, arrayLength, (byte) 0),
                                new ByteArrayIterator(numElements, arrayLength, (byte) 1)),
                        new CombiningIterator<byte[]>(
                                new ByteArrayIterator(1L, arrayLength, (byte) 0),
                                new ByteArrayIterator(1L, arrayLength, (byte) 1)));

                while (table.nextRecord()) {
                    MutableObjectIterator<byte[]> iterator = table.getBuildSideIterator();
                    int counter = 0;
                    while (iterator.next() != null) {
                        counter++;
                    }
                    Assertions.assertThat(counter).isEqualTo(numElements);
                }
            } finally {
                // Release memory even when the assertion above fails.
                table.close();
            }
        }
    }

    /**
     * Allocates unpooled memory segments.
     *
     * @param numSegments number of segments to allocate
     * @param segmentSize size of each segment in bytes
     * @return a mutable list containing the freshly allocated segments
     */
    private static List<MemorySegment> getMemory(int numSegments, int segmentSize) {
        List<MemorySegment> list = new ArrayList<>(numSegments);
        for (int i = 0; i < numSegments; i++) {
            list.add(MemorySegmentFactory.allocateUnpooledSegment(segmentSize));
        }
        return list;
    }

    /**
     * Fails the current test if any file is left in one of the I/O manager's spilling
     * directories.
     *
     * @param ioManager the I/O manager whose spilling directories are inspected
     */
    private static void checkNoTempFilesRemain(IOManager ioManager) {
        for (File dir : ioManager.getSpillingDirectories()) {
            String[] files = dir.list();
            if (files == null) {
                // File#list returns null when dir is not a directory (e.g. it was already
                // removed) or on an I/O error; there is nothing to inspect in that case.
                continue;
            }
            for (String file : files) {
                if (file != null && !file.equals(".") && !file.equals("..")) {
                    Assertions.fail(
                            "hash table did not clean up temp files. remaining file: " + file);
                }
            }
        }
    }

    /**
     * Yields the records of {@code left} until it returns {@code null}, then those of
     * {@code right}. Note that {@code left} is queried again on every call, so it must keep
     * returning {@code null} once exhausted (all iterators in this class do).
     */
    private static class CombiningIterator<T> implements MutableObjectIterator<T> {
        private final MutableObjectIterator<T> left;
        private final MutableObjectIterator<T> right;

        public CombiningIterator(MutableObjectIterator<T> left, MutableObjectIterator<T> right) {
            this.left = left;
            this.right = right;
        }

        @Override
        public T next(T reuse) throws IOException {
            T value = left.next(reuse);
            return value != null ? value : right.next(reuse);
        }

        @Override
        public T next() throws IOException {
            T value = left.next();
            return value != null ? value : right.next();
        }
    }

    /** Yields {@code numRecords} fresh {@code ByteValue(0)} instances, then {@code null}. */
    private static class ByteValueIterator implements MutableObjectIterator<ByteValue> {
        private final long numRecords;
        private long value = 0L;

        ByteValueIterator(long numRecords) {
            this.numRecords = numRecords;
        }

        @Override
        public ByteValue next(ByteValue aLong) {
            return next();
        }

        @Override
        public ByteValue next() {
            if (value++ < numRecords) {
                // ByteValue takes a byte; an int literal would not compile here.
                return new ByteValue((byte) 0);
            }
            return null;
        }
    }

    /** Yields the longs {@code 0 .. numRecords-1}, then {@code null}. */
    private static class LongIterator implements MutableObjectIterator<Long> {
        private final long numRecords;
        private long value = 0L;

        LongIterator(long numRecords) {
            this.numRecords = numRecords;
        }

        @Override
        public Long next(Long aLong) {
            return next();
        }

        @Override
        public Long next() {
            if (value < numRecords) {
                return value++;
            }
            return null;
        }
    }

    /**
     * Yields the same {@code length}-byte array (every byte set to {@code value})
     * {@code numRecords} times, then {@code null}. The array instance is shared between calls.
     */
    private static class ByteArrayIterator implements MutableObjectIterator<byte[]> {
        private final long numRecords;
        private long counter = 0L;
        private final byte[] arrayValue;

        ByteArrayIterator(long numRecords, int length, byte value) {
            this.numRecords = numRecords;
            this.arrayValue = new byte[length];
            Arrays.fill(this.arrayValue, value);
        }

        @Override
        public byte[] next(byte[] array) {
            return next();
        }

        @Override
        public byte[] next() {
            if (counter++ < numRecords) {
                return arrayValue;
            }
            return null;
        }
    }

    /**
     * Yields {@code numRecords} tuples {@code (42L, payload)}, all sharing the same payload
     * array and the same key, then {@code null}.
     */
    private static class TupleBytesIterator
            implements MutableObjectIterator<Tuple2<Long, byte[]>> {
        private final byte[] payload;
        private final int numRecords;
        private int count = 0;

        TupleBytesIterator(byte[] payload, int numRecords) {
            this.payload = payload;
            this.numRecords = numRecords;
        }

        @Override
        public Tuple2<Long, byte[]> next(Tuple2<Long, byte[]> reuse) {
            return next();
        }

        @Override
        public Tuple2<Long, byte[]> next() {
            if (count++ < numRecords) {
                return new Tuple2<>(42L, payload);
            }
            return null;
        }
    }
}

