/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.checkpoint.channel;

import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Random;
import java.util.concurrent.ExecutionException;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.core.memory.MemorySegmentFactory;
import org.apache.flink.core.memory.MemorySegmentProvider;
import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
import org.apache.flink.runtime.checkpoint.StateObjectCollection;
import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
import org.apache.flink.runtime.checkpoint.channel.ChannelStateSerializer;
import org.apache.flink.runtime.checkpoint.channel.ChannelStateSerializerImpl;
import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo;
import org.apache.flink.runtime.checkpoint.channel.ResultSubpartitionInfo;
import org.apache.flink.runtime.checkpoint.channel.SequentialChannelStateReader;
import org.apache.flink.runtime.checkpoint.channel.SequentialChannelStateReaderImpl;
import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
import org.apache.flink.runtime.io.network.buffer.Buffer;
import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler;
import org.apache.flink.runtime.io.network.buffer.NetworkBuffer;
import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool;
import org.apache.flink.runtime.io.network.partition.BufferAvailabilityListener;
import org.apache.flink.runtime.io.network.partition.BufferWritingResultPartition;
import org.apache.flink.runtime.io.network.partition.NoOpBufferAvailablityListener;
import org.apache.flink.runtime.io.network.partition.ResultPartitionBuilder;
import org.apache.flink.runtime.io.network.partition.ResultSubpartition;
import org.apache.flink.runtime.io.network.partition.ResultSubpartitionView;
import org.apache.flink.runtime.io.network.partition.consumer.BufferOrEvent;
import org.apache.flink.runtime.io.network.partition.consumer.InputGate;
import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate;
import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGateBuilder;
import org.apache.flink.runtime.jobgraph.OperatorID;
import org.apache.flink.runtime.state.InputChannelStateHandle;
import org.apache.flink.runtime.state.ResultSubpartitionStateHandle;
import org.apache.flink.runtime.state.StreamStateHandle;
import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
import org.apache.flink.shaded.guava30.com.google.common.io.Closer;
import org.apache.flink.util.function.ThrowingConsumer;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;

/**
 * Tests {@link SequentialChannelStateReaderImpl}: randomly generated state for input channels and
 * result subpartitions is serialized into a single stream (channels in shuffled order), then read
 * back into freshly built result partitions and input gates, and the recovered bytes are compared
 * per channel with what was written.
 */
@RunWith(value = Parameterized.class)
public class SequentialChannelStateReaderImplTest {
    private final ChannelStateSerializer serializer = new ChannelStateSerializerImpl();
    private final Random random = new Random();

    /** Number of partitions/gates, and of subpartitions/channels in each of them. */
    private final int parLevel;
    /** Number of byte[] parts generated for every channel. */
    private final int statePartsPerChannel;
    /** Size in bytes of every generated state part. */
    private final int stateBytesPerPart;
    /** Memory segment size used for the network buffer pools. */
    private final int bufferSize;
    /** Number of gate ids and channel ids (each) for which state is generated. */
    private final int stateParLevel;
    /** Number of buffers reserved per channel, derived from the state and buffer sizes. */
    private final int buffersPerChannel;

    /**
     * Test parameters. Each row is {description, stateParLevel, statePartsPerChannel,
     * stateBytesPerPart, parLevel, bufferSize}.
     *
     * <p>NOTE(review): no row has {@code bufferSize < stateBytesPerPart}, so the multi-buffer branch
     * of {@code buffersPerChannel} is never exercised; confirm whether
     * "ReadPermutedStateWithReducedBuffer" was meant to vary stateBytesPerPart rather than parLevel.
     */
    @Parameterized.Parameters(name = "{0}: stateParLevel={1}, statePartsPerChannel={2}, stateBytesPerPart={3},  parLevel={4}, bufferSize={5}")
    public static Object[][] parameters() {
        return new Object[][] {
            {"NoStateAndNoChannels", 0, 0, 0, 0, 0},
            {"NoState", 0, 10, 10, 10, 10},
            {"ReadPermutedStateWithEqualBuffer", 10, 10, 10, 10, 10},
            {"ReadPermutedStateWithReducedBuffer", 10, 10, 10, 20, 10},
            {"ReadPermutedStateWithIncreasedBuffer", 10, 10, 10, 10, 20},
        };
    }

    /**
     * Creates a test instance for one parameter row.
     *
     * @param desc human-readable case name, only used in the parameterized test name
     */
    public SequentialChannelStateReaderImplTest(String desc, int stateParLevel, int statePartsPerChannel, int stateBytesPerPart, int parLevel, int bufferSize) {
        this.parLevel = parLevel;
        this.statePartsPerChannel = statePartsPerChannel;
        this.stateBytesPerPart = stateBytesPerPart;
        this.bufferSize = bufferSize;
        this.stateParLevel = stateParLevel;
        // A part larger than a buffer spans ceil(stateBytesPerPart / bufferSize) buffers; the
        // ternary guard keeps the 0/0 "no channels" case away from the division.
        int buffersPerPart = bufferSize >= stateBytesPerPart ? 1 : (stateBytesPerPart + bufferSize - 1) / bufferSize;
        this.buffersPerChannel = Math.max(1, statePartsPerChannel * buffersPerPart);
    }

    /** Writes permuted state, then checks that output and input state are restored per channel. */
    @Test
    public void testReadPermutedState() throws Exception {
        Map<InputChannelInfo, List<byte[]>> inputChannelsData = generateState(InputChannelInfo::new);
        Map<ResultSubpartitionInfo, List<byte[]>> resultPartitionsData = generateState(ResultSubpartitionInfo::new);

        SequentialChannelStateReader reader = new SequentialChannelStateReaderImpl(buildSnapshot(writePermuted(inputChannelsData, resultPartitionsData)));

        withResultPartitions(resultPartitions -> {
            reader.readOutputData(resultPartitions, false);
            assertBuffersEquals(resultPartitionsData, collectBuffers(resultPartitions));
        });

        withInputGates(gates -> {
            reader.readInputData(gates);
            assertBuffersEquals(inputChannelsData, collectBuffers(gates));
            assertConsumed(gates);
        });
    }

    /**
     * Drains every subpartition of the given partitions and groups the data buffers by subpartition.
     * Non-data buffers are recycled immediately because the caller of getNextBuffer owns them.
     */
    private Map<ResultSubpartitionInfo, List<Buffer>> collectBuffers(BufferWritingResultPartition[] resultPartitions) throws IOException {
        Map<ResultSubpartitionInfo, List<Buffer>> actual = new HashMap<>();
        for (BufferWritingResultPartition resultPartition : resultPartitions) {
            for (int i = 0; i < resultPartition.getNumberOfSubpartitions(); i++) {
                ResultSubpartitionInfo info = resultPartition.getAllPartitions()[i].getSubpartitionInfo();
                ResultSubpartitionView view = resultPartition.createSubpartitionView(info.getSubPartitionIdx(), new NoOpBufferAvailablityListener());
                for (ResultSubpartition.BufferAndBacklog next = view.getNextBuffer(); next != null; next = view.getNextBuffer()) {
                    Buffer buffer = next.buffer();
                    if (buffer.isBuffer()) {
                        actual.computeIfAbsent(info, unused -> new ArrayList<>()).add(buffer);
                    } else {
                        buffer.recycleBuffer();
                    }
                }
            }
        }
        return actual;
    }

    /** Polls every gate until it is empty and groups the received data buffers by channel. */
    private Map<InputChannelInfo, List<Buffer>> collectBuffers(InputGate[] gates) throws Exception {
        Map<InputChannelInfo, List<Buffer>> actual = new HashMap<>();
        for (InputGate gate : gates) {
            for (Optional<BufferOrEvent> next = gate.pollNext(); next.isPresent(); next = gate.pollNext()) {
                BufferOrEvent bufferOrEvent = next.get();
                if (bufferOrEvent.isBuffer()) {
                    actual.computeIfAbsent(bufferOrEvent.getChannelInfo(), unused -> new ArrayList<>()).add(bufferOrEvent.getBuffer());
                }
            }
        }
        return actual;
    }

    /**
     * Asserts that each gate's state-consumed future is done, and rethrows (via get) any exception
     * it completed with.
     */
    private void assertConsumed(InputGate[] gates) throws InterruptedException, ExecutionException {
        for (InputGate gate : gates) {
            Assert.assertTrue(gate.getStateConsumedFuture().isDone());
            gate.getStateConsumedFuture().get();
        }
    }

    /**
     * Builds parLevel input gates with parLevel recovered channels each, runs the action on them,
     * closes the gates and asserts every memory segment went back to the pool.
     *
     * <p>The pool is sized as one segment per gate (the gate pools are created with a minimum of 1)
     * plus buffersPerChannel segments for each channel of each gate.
     */
    private void withInputGates(ThrowingConsumer<InputGate[], Exception> action) throws Exception {
        SingleInputGate[] gates = new SingleInputGate[parLevel];
        final int segmentsToAllocate = parLevel + parLevel * parLevel * buffersPerChannel;
        NetworkBufferPool networkBufferPool = new NetworkBufferPool(segmentsToAllocate, bufferSize);
        try (Closer poolCloser = Closer.create()) {
            // Closer closes in reverse registration order: pools are destroyed before the pool itself.
            poolCloser.register(networkBufferPool::destroy);
            poolCloser.register(networkBufferPool::destroyAllBufferPools);

            try (Closer gateCloser = Closer.create()) {
                for (int i = 0; i < parLevel; i++) {
                    SingleInputGate gate = new SingleInputGateBuilder()
                            .setNumberOfChannels(parLevel)
                            .setSingleInputGateIndex(i)
                            .setBufferPoolFactory(networkBufferPool.createBufferPool(1, buffersPerChannel))
                            .setSegmentProvider(networkBufferPool)
                            .setChannelFactory((builder, owner) -> builder.setNetworkBuffersPerChannel(buffersPerChannel).buildRemoteRecoveredChannel(owner))
                            .build();
                    gates[i] = gate;
                    gate.setup();
                    // Capture the local, not gates[i]: the loop index is not effectively final.
                    gateCloser.register(gate::close);
                }
                action.accept(gates);
            }
            // Gates are closed at this point, so all their segments must have been returned.
            Assert.assertEquals(segmentsToAllocate, networkBufferPool.getNumberOfAvailableMemorySegments());
        }
    }

    /**
     * Builds parLevel result partitions with parLevel subpartitions each, runs the action on them,
     * then closes the partitions, asserts every segment was returned and destroys the pool.
     */
    private void withResultPartitions(ThrowingConsumer<BufferWritingResultPartition[], Exception> action) throws Exception {
        final int segmentsToAllocate = parLevel * parLevel * buffersPerChannel;
        NetworkBufferPool networkBufferPool = new NetworkBufferPool(segmentsToAllocate, bufferSize);
        BufferWritingResultPartition[] resultPartitions = IntStream.range(0, parLevel)
                .mapToObj(i -> (BufferWritingResultPartition) new ResultPartitionBuilder()
                        .setResultPartitionIndex(i)
                        .setNumberOfSubpartitions(parLevel)
                        .setNetworkBufferPool(networkBufferPool)
                        .build())
                .toArray(BufferWritingResultPartition[]::new);
        try {
            for (BufferWritingResultPartition resultPartition : resultPartitions) {
                resultPartition.setup();
            }
            action.accept(resultPartitions);
        } finally {
            for (BufferWritingResultPartition resultPartition : resultPartitions) {
                resultPartition.close();
            }
            try {
                Assert.assertEquals(segmentsToAllocate, networkBufferPool.getNumberOfAvailableMemorySegments());
            } finally {
                networkBufferPool.destroyAllBufferPools();
                networkBufferPool.destroy();
            }
        }
    }

    /** Wraps the given handles into a snapshot with a single (random) operator. */
    private TaskStateSnapshot buildSnapshot(Tuple2<List<InputChannelStateHandle>, List<ResultSubpartitionStateHandle>> handles) {
        return new TaskStateSnapshot(Collections.singletonMap(
                new OperatorID(),
                OperatorSubtaskState.builder()
                        .setInputChannelState(new StateObjectCollection<>(handles.f0))
                        .setResultSubpartitionState(new StateObjectCollection<>(handles.f1))
                        .build()));
    }

    /**
     * Generates random state for every (gateId, channelId) pair in [0, stateParLevel)^2.
     *
     * @param descriptorCreator builds the channel descriptor from (gateId, channelId)
     */
    private <T> Map<T, List<byte[]>> generateState(BiFunction<Integer, Integer, T> descriptorCreator) {
        return IntStream.range(0, stateParLevel)
                .boxed()
                .flatMap(gateId -> IntStream.range(0, stateParLevel).mapToObj(channelId -> descriptorCreator.apply(gateId, channelId)))
                .collect(Collectors.toMap(Function.identity(), this::generateSingleChannelState));
    }

    /** Generates statePartsPerChannel random parts; the channel argument is not used. */
    private List<byte[]> generateSingleChannelState(Object channel) {
        return IntStream.range(0, statePartsPerChannel)
                .mapToObj(unused -> randomStateBytes())
                .collect(Collectors.toList());
    }

    /**
     * Serializes input-channel state and then result-subpartition state into one in-memory stream
     * (each in shuffled channel order) and returns handles pointing at the recorded offsets.
     */
    private Tuple2<List<InputChannelStateHandle>, List<ResultSubpartitionStateHandle>> writePermuted(Map<InputChannelInfo, List<byte[]>> inputChannels, Map<ResultSubpartitionInfo, List<byte[]>> resultSubpartitions) throws IOException {
        try (ByteArrayOutputStream out = new ByteArrayOutputStream()) {
            DataOutputStream dataStream = new DataOutputStream(out);
            serializer.writeHeader(dataStream);
            Map<InputChannelInfo, List<Long>> icOffsets = write(dataStream, permute(inputChannels));
            Map<ResultSubpartitionInfo, List<Long>> rsOffsets = write(dataStream, permute(resultSubpartitions));
            ByteStreamStateHandle streamStateHandle = new ByteStreamStateHandle("", out.toByteArray());

            List<InputChannelStateHandle> inputHandles = icOffsets.entrySet().stream()
                    .map(e -> new InputChannelStateHandle(e.getKey(), streamStateHandle, e.getValue()))
                    .collect(Collectors.toList());
            List<ResultSubpartitionStateHandle> outputHandles = rsOffsets.entrySet().stream()
                    .map(e -> new ResultSubpartitionStateHandle(e.getKey(), streamStateHandle, e.getValue()))
                    .collect(Collectors.toList());
            return Tuple2.of(inputHandles, outputHandles);
        }
    }

    /**
     * Shuffles the channel order and flattens to (part, channel) pairs. Parts of one channel stay
     * contiguous and in their original order.
     */
    private <T> List<Tuple2<byte[], T>> permute(Map<T, List<byte[]>> inputChannels) {
        List<Map.Entry<T, List<byte[]>>> entries = new ArrayList<>(inputChannels.entrySet());
        Collections.shuffle(entries);
        return entries.stream()
                .flatMap(e -> e.getValue().stream().map(bytes -> Tuple2.of(bytes, e.getKey())))
                .collect(Collectors.toList());
    }

    /**
     * Writes every part to the stream and records, per channel, the stream offset at which each of
     * its parts starts.
     */
    private <T> Map<T, List<Long>> write(DataOutputStream dataStream, List<Tuple2<byte[], T>> partsPermuted) throws IOException {
        Map<T, List<Long>> offsets = new HashMap<>();
        for (Tuple2<byte[], T> part : partsPermuted) {
            offsets.computeIfAbsent(part.f1, unused -> new ArrayList<>()).add((long) dataStream.size());
            NetworkBuffer buffer = wrap(part.f0);
            try {
                serializer.writeData(dataStream, new Buffer[] {buffer});
            } finally {
                // Plain null-safe recycle: a `continue` here would silently drop a pending exception.
                buffer.recycleBuffer();
            }
        }
        return offsets;
    }

    /** Wraps the bytes (without copying) into a data buffer freed on recycle. */
    private NetworkBuffer wrap(byte[] bytes) {
        return new NetworkBuffer(MemorySegmentFactory.wrap(bytes), FreeingBufferRecycler.INSTANCE, Buffer.DataType.DATA_BUFFER, bytes.length);
    }

    /** Returns stateBytesPerPart random bytes. */
    private byte[] randomStateBytes() {
        byte[] buf = new byte[stateBytesPerPart];
        random.nextBytes(buf);
        return buf;
    }

    /**
     * Compares, per channel, the concatenation of expected parts with the concatenation of actual
     * buffer contents; actual buffers are always recycled afterwards.
     */
    private <T> void assertBuffersEquals(Map<T, List<byte[]>> expected, Map<T, List<Buffer>> actual) {
        try {
            Assert.assertEquals(mapValues(expected, this::concat), mapValues(actual, buffers -> concat(toBytes(buffers))));
        } finally {
            actual.values().stream().flatMap(Collection::stream).forEach(Buffer::recycleBuffer);
        }
    }

    /** Returns a new map with the same keys and mapFn applied to every value. */
    private static <K, V1, V2> Map<K, V2> mapValues(Map<K, V1> map, Function<V1, V2> mapFn) {
        return map.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, e -> mapFn.apply(e.getValue())));
    }

    /** Concatenates the byte arrays and wraps the result in a buffer. */
    private NetworkBuffer concat(List<byte[]> list) {
        try (ByteArrayOutputStream outputStream = new ByteArrayOutputStream()) {
            for (byte[] bytes : list) {
                outputStream.write(bytes);
            }
            return wrap(outputStream.toByteArray());
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /** Copies the readable content (size bytes from index 0) of each buffer into a byte array. */
    private List<byte[]> toBytes(List<Buffer> buffers) {
        return buffers.stream().map(buffer -> {
            byte[] buf = new byte[buffer.getSize()];
            buffer.getNioBuffer(0, buffer.getSize()).get(buf, 0, buf.length);
            return buf;
        }).collect(Collectors.toList());
    }
}

