/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.gateway.rest;

import java.io.IOException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.util.Collection;
import java.util.Collections;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.Response;
import org.apache.flink.api.common.time.Time;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.SecurityOptions;
import org.apache.flink.core.testutils.BlockerSync;
import org.apache.flink.core.testutils.OneShotLatch;
import org.apache.flink.runtime.rest.HttpMethodWrapper;
import org.apache.flink.runtime.rest.RestClient;
import org.apache.flink.runtime.rest.RestServerEndpoint;
import org.apache.flink.runtime.rest.handler.HandlerRequest;
import org.apache.flink.runtime.rest.handler.RestHandlerSpecification;
import org.apache.flink.runtime.rest.messages.EmptyMessageParameters;
import org.apache.flink.runtime.rest.messages.EmptyRequestBody;
import org.apache.flink.runtime.rest.messages.MessageHeaders;
import org.apache.flink.runtime.rest.messages.MessageParameters;
import org.apache.flink.runtime.rest.messages.RequestBody;
import org.apache.flink.runtime.rest.messages.ResponseBody;
import org.apache.flink.runtime.rest.util.RestClientException;
import org.apache.flink.runtime.rest.versioning.RestAPIVersion;
import org.apache.flink.runtime.rpc.exceptions.EndpointNotStartedException;
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonCreator;
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonProperty;
import org.apache.flink.shaded.netty4.io.netty.channel.ChannelInboundHandler;
import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.HttpResponseStatus;
import org.apache.flink.table.gateway.api.SqlGatewayService;
import org.apache.flink.table.gateway.rest.TestingSqlGatewayRestEndpoint;
import org.apache.flink.table.gateway.rest.handler.AbstractSqlGatewayRestHandler;
import org.apache.flink.table.gateway.rest.header.SqlGatewayMessageHeaders;
import org.apache.flink.table.gateway.rest.util.RestConfigUtils;
import org.apache.flink.table.gateway.rest.util.SqlGatewayRestAPIVersion;
import org.apache.flink.util.ExceptionUtils;
import org.apache.flink.util.FlinkRuntimeException;
import org.apache.flink.util.concurrent.ExecutorThreadFactory;
import org.apache.flink.util.concurrent.FutureUtils;
import org.assertj.core.api.Assertions;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

class SqlGatewayRestEndpointITCase {
    private static final SqlGatewayService service = null;
    private static RestServerEndpoint serverEndpoint;
    private static RestClient restClient;
    private static InetSocketAddress serverAddress;
    private static TestBadCaseHandler testHandler;
    private static TestVersionSelectionHeaders1 header1;
    private static TestVersionSelectionHeaders2 header2;
    private static TestBadCaseHeaders badCaseHeader;
    private static TestVersionHandler testVersionHandler1;
    private static TestVersionHandler testVersionHandler2;
    private static Configuration config;
    private static final Time timeout;

    /**
     * Creates the test instance. No state is initialised here: the fixtures of this class are
     * static fields that {@code setup()} assigns.
     */
    SqlGatewayRestEndpointITCase() {
        super();
    }

    /**
     * Builds and starts a fresh server endpoint and creates a REST client for it.
     *
     * <p>Three handlers are registered: the bad-case handler under {@code badCaseHeader} and one
     * version handler under each of {@code header1} and {@code header2}. The port argument passed
     * to the configuration helper is {@code "0"}; the address the endpoint reports after start-up
     * is stored in {@code serverAddress}.
     *
     * @throws Exception if the endpoint cannot be built or started
     */
    @BeforeEach
    void setup() throws Exception {
        // Headers and handlers used by the version-selection tests.
        header1 = new TestVersionSelectionHeaders1();
        header2 = new TestVersionSelectionHeaders2();
        testVersionHandler1 = new TestVersionHandler(service, header1);
        testVersionHandler2 = new TestVersionHandler(service, header2);

        // badCaseHeader must be assigned first: TestBadCaseHandler's constructor reads this field.
        badCaseHeader = new TestBadCaseHeaders();
        testHandler = new TestBadCaseHandler(service);

        final String loopback = InetAddress.getLoopbackAddress().getHostAddress();
        config = RestConfigUtils.getBaseConfig(RestConfigUtils.getFlinkConfig(loopback, loopback, "0"));

        serverEndpoint =
                TestingSqlGatewayRestEndpoint.builder(config, service)
                        .withHandler(badCaseHeader, testHandler)
                        .withHandler(header1, testVersionHandler1)
                        .withHandler(header2, testVersionHandler2)
                        .buildAndStart();

        final Executor clientExecutor =
                Executors.newFixedThreadPool(1, new ExecutorThreadFactory("rest-client-thread-pool"));
        restClient = new RestClient(config, clientExecutor);
        serverAddress = serverEndpoint.getServerAddress();
    }

    /**
     * Shuts down the REST client and closes the server endpoint after every test.
     *
     * <p>The endpoint is closed in a {@code finally} block so that a failure while shutting down
     * the client cannot leave the server endpoint running for the following tests.
     *
     * @throws Exception if shutting down the client fails, or if closing the endpoint fails or
     *     does not complete within {@code timeout}
     */
    @AfterEach
    void stop() throws Exception {
        try {
            if (restClient != null) {
                restClient.shutdown(timeout);
                restClient = null;
            }
        } finally {
            if (serverEndpoint != null) {
                serverEndpoint.closeAsync().get(timeout.getSize(), timeout.getUnit());
                serverEndpoint = null;
            }
        }
    }

    /**
     * Sends requests through {@code header2}, whose only supported version is V1, with an
     * unsupported version, with the supported version, and without naming a version.
     *
     * @throws Exception if a request fails unexpectedly or does not answer within five seconds
     */
    @Test
    void testSqlGatewayMessageHeaders() throws Exception {
        final String host = serverAddress.getHostName();
        final int port = serverAddress.getPort();
        final EmptyMessageParameters noParameters = EmptyMessageParameters.getInstance();
        final EmptyRequestBody noBody = EmptyRequestBody.getInstance();

        // Requesting V0 through a header that only lists V1 is expected to be rejected.
        Assertions.assertThatThrownBy(
                        () ->
                                restClient.sendRequest(
                                        host,
                                        port,
                                        header2,
                                        noParameters,
                                        noBody,
                                        Collections.emptyList(),
                                        SqlGatewayRestAPIVersion.V0))
                .isInstanceOf(IllegalArgumentException.class);

        // Explicitly requesting V1 reaches the handler, which echoes the version name.
        final CompletableFuture<TestResponse> specifiedVersionResponse =
                restClient.sendRequest(
                        host,
                        port,
                        header2,
                        noParameters,
                        noBody,
                        Collections.emptyList(),
                        SqlGatewayRestAPIVersion.V1);
        final TestResponse specified = specifiedVersionResponse.get(5L, TimeUnit.SECONDS);
        Assertions.assertThat(specified.getStatus()).isEqualTo("V1");

        // Without an explicit version the answer is expected to be V1 as well.
        final CompletableFuture<TestResponse> unspecifiedVersionResponse =
                restClient.sendRequest(
                        host, port, header2, noParameters, noBody, Collections.emptyList());
        final TestResponse unspecified = unspecifiedVersionResponse.get(5L, TimeUnit.SECONDS);
        Assertions.assertThat(unspecified.getStatus()).isEqualTo("V1");
    }

    /**
     * Checks that a V0 request sent through {@code header1} and a V1 request sent through
     * {@code header2} are each answered with the matching version name, although both headers
     * share one target URL.
     *
     * @throws Exception if a request fails or does not answer within five seconds
     */
    @Test
    void testVersionSelection() throws Exception {
        final String host = serverAddress.getHostName();
        final int port = serverAddress.getPort();

        final CompletableFuture<TestResponse> v0Future =
                restClient.sendRequest(
                        host,
                        port,
                        header1,
                        EmptyMessageParameters.getInstance(),
                        EmptyRequestBody.getInstance(),
                        Collections.emptyList(),
                        SqlGatewayRestAPIVersion.V0);
        final TestResponse v0Answer = v0Future.get(5L, TimeUnit.SECONDS);
        Assertions.assertThat(v0Answer.getStatus()).isEqualTo("V0");

        final CompletableFuture<TestResponse> v1Future =
                restClient.sendRequest(
                        host,
                        port,
                        header2,
                        EmptyMessageParameters.getInstance(),
                        EmptyRequestBody.getInstance(),
                        Collections.emptyList(),
                        SqlGatewayRestAPIVersion.V1);
        final TestResponse v1Answer = v1Future.get(5L, TimeUnit.SECONDS);
        Assertions.assertThat(v1Answer.getStatus()).isEqualTo("V1");
    }

    /**
     * Calls the version-selection URL with a plain HTTP client, without this test adding any API
     * version to the URL, and expects the response body to contain {@code "V1"}.
     *
     * <p>The HTTP response is held in a try-with-resources block so it is released even when an
     * assertion fails, and the body null check uses AssertJ so it is effective regardless of
     * whether JVM assertions ({@code -ea}) are enabled.
     *
     * @throws Exception if the HTTP call or reading the body fails
     */
    @Test
    void testDefaultVersionRouting() throws Exception {
        // NOTE(review): presumably required because the raw client below is not set up for SSL.
        Assertions.assertThat(config.getBoolean(SecurityOptions.SSL_REST_ENABLED)).isFalse();

        final OkHttpClient client = new OkHttpClient();
        final String url = serverEndpoint.getRestBaseUrl() + header1.getTargetRestEndpointURL();
        final Request request = new Request.Builder().url(url).build();

        try (Response response = client.newCall(request).execute()) {
            Assertions.assertThat((Object) response.body()).isNotNull();
            Assertions.assertThat(response.body().string()).contains("V1");
        }
    }

    /**
     * Verifies that a second request can be answered while the first one is still blocked inside
     * the handler.
     *
     * @throws Exception if waiting on the blocker or on one of the responses fails
     */
    @Test
    void testRequestInterleaving() throws Exception {
        final BlockerSync blocker = new BlockerSync();

        // Only the request with id 1 parks on the blocker; every request echoes its id.
        testHandler.handlerBody =
                requestId -> {
                    final boolean mustBlock = requestId == 1;
                    if (mustBlock) {
                        try {
                            blocker.block();
                        } catch (InterruptedException interrupted) {
                            Thread.currentThread().interrupt();
                        }
                    }
                    final TestResponse echo = new TestResponse(requestId.toString());
                    return CompletableFuture.completedFuture(echo);
                };

        // First request: wait until the handler has reached the blocker.
        final CompletableFuture<TestResponse> blockedResponse =
                sendRequestToTestHandler(new TestRequest(1));
        blocker.awaitBlocker();

        // Second request: must complete although the first is still parked.
        final CompletableFuture<TestResponse> freeResponse =
                sendRequestToTestHandler(new TestRequest(2));
        final TestResponse second = freeResponse.get();
        Assertions.assertThat(second.status).isEqualTo("2");

        // Release the first request and check its answer.
        blocker.releaseBlocker();
        final TestResponse first = blockedResponse.get();
        Assertions.assertThat(first.status).isEqualTo("1");
    }

    /**
     * Expects a {@link FlinkRuntimeException} when one handler instance is registered under two
     * different headers and the endpoint is built and started.
     */
    @Test
    void testDuplicateHandlerRegistrationIsForbidden() {
        Assertions.assertThatThrownBy(
                        () -> {
                            // testHandler appears twice, once per header.
                            try (TestingSqlGatewayRestEndpoint endpoint =
                                    TestingSqlGatewayRestEndpoint.builder(config, service)
                                            .withHandler(header1, testHandler)
                                            .withHandler(badCaseHeader, testHandler)
                                            .build()) {
                                endpoint.start();
                            }
                        })
                .isInstanceOf(FlinkRuntimeException.class);
    }

    /**
     * Expects a {@link FlinkRuntimeException} when two different handlers are registered under
     * the same header and the endpoint is built and started.
     */
    @Test
    void testEndpointsMustBeUnique() {
        Assertions.assertThatThrownBy(
                        () -> {
                            // badCaseHeader appears twice, each time with another handler.
                            try (TestingSqlGatewayRestEndpoint endpoint =
                                    TestingSqlGatewayRestEndpoint.builder(config, service)
                                            .withHandler(badCaseHeader, testHandler)
                                            .withHandler(badCaseHeader, testVersionHandler1)
                                            .build()) {
                                endpoint.start();
                            }
                        })
                .isInstanceOf(FlinkRuntimeException.class);
    }

    /**
     * Verifies that closing the server endpoint does not complete while the handler's close
     * future is pending, nor while a request is still being processed.
     *
     * @throws Exception if the request or the close operation fails or exceeds {@code timeout}
     */
    @Test
    void testShouldWaitForHandlersWhenClosing() throws Exception {
        // The handler hands out this still-incomplete future from closeHandlerAsync().
        testHandler.closeFuture = new CompletableFuture<>();

        final BlockerSync blocker = new BlockerSync();
        // The response is produced asynchronously via supplyAsync and parks on the blocker first.
        testHandler.handlerBody =
                requestId ->
                        CompletableFuture.supplyAsync(
                                () -> {
                                    try {
                                        blocker.block();
                                    } catch (InterruptedException interrupted) {
                                        Thread.currentThread().interrupt();
                                    }
                                    return new TestResponse(requestId.toString());
                                });

        // Start closing; the close is expected to still be pending.
        final CompletableFuture<Void> endpointClosed = serverEndpoint.closeAsync();
        Assertions.assertThat(endpointClosed).isNotDone();

        // Put one request in flight and wait until it is parked inside the handler.
        final CompletableFuture<TestResponse> inFlight = sendRequestToTestHandler(new TestRequest(1));
        blocker.awaitBlocker();

        // Completing the handler's close future alone must not finish the close.
        testHandler.closeFuture.complete(null);
        Assertions.assertThat(endpointClosed).isNotDone();

        // Let the request finish; afterwards both the request and the close must complete.
        blocker.releaseBlocker();
        inFlight.get(timeout.getSize(), timeout.getUnit());
        endpointClosed.get(timeout.getSize(), timeout.getUnit());
    }

    /**
     * Starts two endpoints configured with the same port range and checks that they report
     * different ports, both inside the range.
     *
     * <p>The range bounds are defined once and reused for the configuration string and for the
     * assertions, so the configured range and the verified range cannot drift apart.
     *
     * @throws Exception if an endpoint cannot be built, started or closed
     */
    @Test
    void testRestServerBindPort() throws Exception {
        final int portRangeStart = 52300;
        final int portRangeEnd = 52400;
        final String portRange = portRangeStart + "-" + portRangeEnd;

        final String address = InetAddress.getLoopbackAddress().getHostAddress();
        final Configuration sqlGatewayRestEndpointConfig =
                RestConfigUtils.getBaseConfig(
                        RestConfigUtils.getFlinkConfig(address, address, portRange));

        try (TestingSqlGatewayRestEndpoint serverEndpoint1 =
                        TestingSqlGatewayRestEndpoint.builder(sqlGatewayRestEndpointConfig, service)
                                .build();
                TestingSqlGatewayRestEndpoint serverEndpoint2 =
                        TestingSqlGatewayRestEndpoint.builder(sqlGatewayRestEndpointConfig, service)
                                .build()) {
            serverEndpoint1.start();
            serverEndpoint2.start();

            // Null-check each address once instead of dereferencing it unchecked later on.
            final int port1 = Objects.requireNonNull(serverEndpoint1.getServerAddress()).getPort();
            final int port2 = Objects.requireNonNull(serverEndpoint2.getServerAddress()).getPort();

            Assertions.assertThat(port1).isNotEqualTo(port2);
            Assertions.assertThat(port1)
                    .isGreaterThanOrEqualTo(portRangeStart)
                    .isLessThanOrEqualTo(portRangeEnd);
            Assertions.assertThat(port2)
                    .isGreaterThanOrEqualTo(portRangeStart)
                    .isLessThanOrEqualTo(portRangeEnd);
        }
    }

    /**
     * Sends the request with id 3, which the bad-case handler fails with an
     * {@code EndpointNotStartedException}, and expects the client-side failure to carry HTTP
     * status 503 (service unavailable).
     */
    @Test
    void testOnUnavailableRpcEndpointReturns503() {
        final CompletableFuture<TestResponse> failingResponse =
                sendRequestToTestHandler(new TestRequest(3));

        Assertions.assertThatThrownBy(failingResponse::get)
                // Dig the RestClientException out of the thrown exception's cause chain.
                .extracting(
                        thrown ->
                                ExceptionUtils.findThrowable(
                                        (Throwable) thrown, RestClientException.class))
                .extracting(Optional::get)
                .extracting(RestClientException::getHttpResponseStatus)
                .isEqualTo(HttpResponseStatus.SERVICE_UNAVAILABLE);
    }

    /**
     * Sends {@code testRequest} to the bad-case handler through the shared REST client.
     *
     * @param testRequest request body to send
     * @return future holding the handler's response
     * @throws RuntimeException wrapping the {@link IOException} if sending the request fails
     */
    private CompletableFuture<TestResponse> sendRequestToTestHandler(TestRequest testRequest) {
        final CompletableFuture<TestResponse> pending;
        try {
            pending =
                    restClient.sendRequest(
                            serverAddress.getHostName(),
                            serverAddress.getPort(),
                            badCaseHeader,
                            EmptyMessageParameters.getInstance(),
                            testRequest);
        } catch (IOException ioFailure) {
            throw new RuntimeException(ioFailure);
        }
        return pending;
    }

    // Ten-second timeout used for client shutdown, endpoint close and the bounded waits in the tests.
    static {
        timeout = Time.seconds(10L);
    }

    /** Handler that answers every request with the name of the API version it was invoked with. */
    private static class TestVersionHandler
            extends AbstractSqlGatewayRestHandler<EmptyRequestBody, TestResponse, EmptyMessageParameters> {

        /**
         * @param sqlGatewayService service handed to the base handler
         * @param header message headers this handler is created for
         */
        TestVersionHandler(SqlGatewayService sqlGatewayService, TestVersionSelectionHeadersBase header) {
            super(sqlGatewayService, Collections.emptyMap(), header);
        }

        /**
         * Echoes the version name back as the response status.
         *
         * @param version version the request was routed with; asserted to be non-null
         * @param request incoming request (its content is not inspected)
         * @return completed future whose response status is {@code version.name()}
         */
        protected CompletableFuture<TestResponse> handleRequest(@Nullable SqlGatewayRestAPIVersion version, @NotNull HandlerRequest<EmptyRequestBody> request) {
            assert version != null;
            final TestResponse reply = new TestResponse(version.name());
            return CompletableFuture.completedFuture(reply);
        }
    }

    /** Version-selection headers that list {@code V1} as the only supported API version. */
    private static class TestVersionSelectionHeaders2 extends TestVersionSelectionHeadersBase {

        private TestVersionSelectionHeaders2() {
            super();
        }

        /** @return a single-element collection holding {@code V1} */
        public Collection<SqlGatewayRestAPIVersion> getSupportedAPIVersions() {
            final SqlGatewayRestAPIVersion onlyVersion = SqlGatewayRestAPIVersion.V1;
            return Collections.singleton(onlyVersion);
        }
    }

    /** Version-selection headers that list {@code V0} as the only supported API version. */
    private static class TestVersionSelectionHeaders1 extends TestVersionSelectionHeadersBase {

        private TestVersionSelectionHeaders1() {
            super();
        }

        /** @return a single-element collection holding {@code V0} */
        public Collection<SqlGatewayRestAPIVersion> getSupportedAPIVersions() {
            final SqlGatewayRestAPIVersion onlyVersion = SqlGatewayRestAPIVersion.V0;
            return Collections.singleton(onlyVersion);
        }
    }

    /**
     * Shared part of the version-selection headers: a body-less {@code GET} on
     * {@code /test/select-version} that answers with a {@link TestResponse} and status
     * {@code OK}.
     */
    private static class TestVersionSelectionHeadersBase
            implements SqlGatewayMessageHeaders<EmptyRequestBody, TestResponse, EmptyMessageParameters> {

        private TestVersionSelectionHeadersBase() {
            super();
        }

        // --- routing ---------------------------------------------------------------------

        /** @return {@code GET} */
        public HttpMethodWrapper getHttpMethod() {
            return HttpMethodWrapper.GET;
        }

        /** @return the fixed path {@code /test/select-version} */
        public String getTargetRestEndpointURL() {
            return "/test/select-version";
        }

        // --- request ---------------------------------------------------------------------

        /** @return {@code EmptyRequestBody.class} */
        public Class<EmptyRequestBody> getRequestClass() {
            return EmptyRequestBody.class;
        }

        /** @return the shared {@link EmptyMessageParameters} instance */
        public EmptyMessageParameters getUnresolvedMessageParameters() {
            return EmptyMessageParameters.getInstance();
        }

        // --- response --------------------------------------------------------------------

        /** @return {@code TestResponse.class} */
        public Class<TestResponse> getResponseClass() {
            return TestResponse.class;
        }

        /** @return {@code OK} */
        public HttpResponseStatus getResponseStatusCode() {
            return HttpResponseStatus.OK;
        }

        /** @return always {@code null}; these headers carry no description */
        public String getDescription() {
            return null;
        }
    }

    /**
     * Headers of the bad-case endpoint: a {@code POST} on {@code /test/} that takes a
     * {@link TestRequest} and answers with a {@link TestResponse} and status {@code OK}.
     */
    private static class TestBadCaseHeaders
            implements SqlGatewayMessageHeaders<TestRequest, TestResponse, EmptyMessageParameters> {

        private TestBadCaseHeaders() {
            super();
        }

        // --- routing ---------------------------------------------------------------------

        /** @return {@code POST} */
        public HttpMethodWrapper getHttpMethod() {
            return HttpMethodWrapper.POST;
        }

        /** @return the fixed path {@code /test/} */
        public String getTargetRestEndpointURL() {
            return "/test/";
        }

        // --- request ---------------------------------------------------------------------

        /** @return {@code TestRequest.class} */
        public Class<TestRequest> getRequestClass() {
            return TestRequest.class;
        }

        /** @return the shared {@link EmptyMessageParameters} instance */
        public EmptyMessageParameters getUnresolvedMessageParameters() {
            return EmptyMessageParameters.getInstance();
        }

        // --- response --------------------------------------------------------------------

        /** @return {@code TestResponse.class} */
        public Class<TestResponse> getResponseClass() {
            return TestResponse.class;
        }

        /** @return {@code OK} */
        public HttpResponseStatus getResponseStatusCode() {
            return HttpResponseStatus.OK;
        }

        /** @return the empty string */
        public String getDescription() {
            return "";
        }
    }

    /** Response body with a single {@code status} string, bound to the JSON property {@code status}. */
    private static class TestResponse implements ResponseBody {

        /** Status text carried by the response. */
        public final String status;

        /**
         * @param status value of the {@code status} JSON property
         */
        @JsonCreator
        public TestResponse(@JsonProperty("status") String status) {
            this.status = status;
        }

        /** @return the status text */
        public String getStatus() {
            return status;
        }
    }

    /** Request body with a single integer {@code id}, bound to the JSON property {@code id}. */
    private static class TestRequest implements RequestBody {

        /** Identifier the bad-case handler uses to decide how to answer. */
        public final int id;

        /**
         * @param id value of the {@code id} JSON property
         */
        @JsonCreator
        public TestRequest(@JsonProperty("id") int id) {
            this.id = id;
        }
    }

    /**
     * Handler whose behaviour is scripted by the individual tests through {@code handlerBody}
     * and {@code closeFuture}.
     */
    private static class TestBadCaseHandler
            extends AbstractSqlGatewayRestHandler<TestRequest, TestResponse, EmptyMessageParameters> {

        /** Triggered when {@code closeHandlerAsync()} is called. */
        private final OneShotLatch closeLatch = new OneShotLatch();

        /** Future returned from {@code closeHandlerAsync()}; already completed unless a test replaces it. */
        private CompletableFuture<Void> closeFuture = CompletableFuture.completedFuture(null);

        /**
         * Maps a request id to the response future. There is no default, so a test has to assign
         * it before sending a request with an id other than 3.
         */
        private Function<Integer, CompletableFuture<TestResponse>> handlerBody;

        /**
         * Registers the handler for the enclosing class's static {@code badCaseHeader}, which must
         * therefore be assigned before this constructor runs.
         *
         * @param sqlGatewayService service handed to the base handler
         */
        TestBadCaseHandler(SqlGatewayService sqlGatewayService) {
            super(sqlGatewayService, Collections.emptyMap(), badCaseHeader);
        }

        /**
         * Triggers the close latch and hands out the current close future.
         *
         * @return the current value of {@code closeFuture}
         */
        public CompletableFuture<Void> closeHandlerAsync() {
            closeLatch.trigger();
            return closeFuture;
        }

        /**
         * Answers a request according to its id.
         *
         * @param version version the request was routed with (not used)
         * @param request request whose body carries the id
         * @return the future produced by {@code handlerBody}, or for id 3 a future that failed
         *     with an {@link EndpointNotStartedException}
         */
        protected CompletableFuture<TestResponse> handleRequest(@Nullable SqlGatewayRestAPIVersion version, @NotNull HandlerRequest<TestRequest> request) {
            final int requestId = request.getRequestBody().id;
            if (requestId != 3) {
                return handlerBody.apply(requestId);
            }
            // Id 3 is reserved for simulating an endpoint that has not been started.
            return FutureUtils.completedExceptionally(new EndpointNotStartedException("test exception"));
        }
    }
}

