/*
 * Decompiled with CFR 0.152.
 */
package com.alibaba.dashscope.protocol.okhttp;

import com.alibaba.dashscope.common.DashScopeResult;
import com.alibaba.dashscope.common.ResultCallback;
import com.alibaba.dashscope.common.Status;
import com.alibaba.dashscope.exception.ApiException;
import com.alibaba.dashscope.exception.NoApiKeyException;
import com.alibaba.dashscope.protocol.DashScopeHeaders;
import com.alibaba.dashscope.protocol.FullDuplexClient;
import com.alibaba.dashscope.protocol.FullDuplexRequest;
import com.alibaba.dashscope.protocol.HalfDuplexClient;
import com.alibaba.dashscope.protocol.HalfDuplexRequest;
import com.alibaba.dashscope.protocol.NetworkResponse;
import com.alibaba.dashscope.protocol.Protocol;
import com.alibaba.dashscope.protocol.StreamingMode;
import com.alibaba.dashscope.protocol.WebSocketEventType;
import com.alibaba.dashscope.protocol.WebSocketResponse;
import com.alibaba.dashscope.protocol.okhttp.OkHttpClientFactory;
import com.alibaba.dashscope.utils.Constants;
import com.alibaba.dashscope.utils.JsonUtils;
import com.google.gson.JsonObject;
import io.reactivex.BackpressureStrategy;
import io.reactivex.Flowable;
import io.reactivex.FlowableEmitter;
import io.reactivex.Observable;
import io.reactivex.functions.Action;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.CancellationException;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import okhttp3.Headers;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.Response;
import okhttp3.WebSocket;
import okhttp3.WebSocketListener;
import okio.ByteString;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * DashScope transport over an OkHttp WebSocket, serving half-duplex calls (one request message,
 * optionally followed by one binary payload) and full-duplex calls (streamed input).
 *
 * <p>Each request method installs a {@link FlowableEmitter} in {@link #responseEmitter}. The
 * {@link WebSocketListener} callbacks of this class translate server events (TASK_STARTED,
 * RESULT_GENERATED, TASK_FINISHED, TASK_FAILED) into {@code onNext}/{@code onComplete}/{@code
 * onError} on that emitter.
 *
 * <p>Only one response emitter is held at a time, so an instance serves one request at a time.
 * Listener callbacks are delivered by OkHttp rather than on the caller's thread. Streamed input is
 * sent from {@link CompletableFuture#runAsync(Runnable)}. Fields shared between these threads are
 * therefore {@code volatile}.
 */
public class OkHttpWebSocketClient extends WebSocketListener implements HalfDuplexClient, FullDuplexClient {
    private static final Logger log = LoggerFactory.getLogger(OkHttpWebSocketClient.class);
    /** Maximum number of attempts made to open the WebSocket connection. */
    private static final int MAX_CONNECTION_TIMES = 3;
    /** Maximum number of attempts made to send one message when retrying is enabled. */
    private static final int MAX_SEND_RETRIES = 3;
    /** Pause between failed connection attempts, in milliseconds. */
    private static final long CONNECTION_RETRY_DELAY_MS = 10000L;
    /** Pause after a failed send (and reconnect) before the next attempt, in milliseconds. */
    private static final long SEND_RETRY_DELAY_MS = 5000L;

    private volatile OkHttpClient client;
    private volatile WebSocket webSocketClient;
    /** Set by {@link #onOpen}; cleared by {@link #onClosed} and {@link #onFailure}. */
    private final AtomicBoolean isOpen = new AtomicBoolean(false);
    /** Set by {@link #close}; while set, incoming callbacks close the socket again and return. */
    private final AtomicBoolean isClosed = new AtomicBoolean(false);
    protected AtomicBoolean isFirstMessage = new AtomicBoolean(false);
    /** Sink for the server events of the request currently in flight. */
    protected volatile FlowableEmitter<DashScopeResult> responseEmitter;
    /** Flatten flag of the current request, used when converting frames to results. */
    private volatile boolean isFlattenResult;
    /** Sink that is completed by {@link #onOpen} or failed by {@link #onFailure} while connecting. */
    private volatile FlowableEmitter<DashScopeResult> connectionEmitter;
    /** When set, bare TASK_STARTED events are forwarded and text sends are not retried. */
    private final AtomicBoolean passTaskStarted = new AtomicBoolean(false);

    /**
     * Creates a client.
     *
     * @param client initial OkHttp client. NOTE(review): every connection attempt replaces it with
     *     {@code OkHttpClientFactory.getOkHttpClient()}, so a custom client passed here is not used
     *     for the WebSocket. Confirm this is intended.
     * @param passTaskStarted whether TASK_STARTED events without output are forwarded to the
     *     response stream. This also disables send retries for text messages.
     */
    public OkHttpWebSocketClient(OkHttpClient client, boolean passTaskStarted) {
        this.client = client;
        this.passTaskStarted.set(passTaskStarted);
    }

    /**
     * Builds the WebSocket handshake request.
     *
     * <p>A caller-supplied {@code user-agent} header is removed from the copied header map. It is
     * handed to {@code DashScopeHeaders.buildWebSocketHeaders} as a separate argument instead.
     *
     * @param baseWebSocketUrl endpoint override; {@code Constants.baseWebsocketApiUrl} when null
     * @throws NoApiKeyException propagated from the header builder
     */
    private Request buildConnectionRequest(String apiKey, boolean isSecurityCheck, String workspace, Map<String, String> customHeaders, String baseWebSocketUrl) throws NoApiKeyException {
        String customUserAgent = null;
        HashMap<String, String> filteredHeaders = new HashMap<>();
        if (customHeaders != null) {
            customUserAgent = customHeaders.get("user-agent");
            filteredHeaders.putAll(customHeaders);
        }
        filteredHeaders.remove("user-agent");
        Request.Builder builder = new Request.Builder();
        builder.headers(Headers.of(DashScopeHeaders.buildWebSocketHeaders(apiKey, isSecurityCheck, workspace, filteredHeaders, customUserAgent)));
        String url = baseWebSocketUrl != null ? baseWebSocketUrl : Constants.baseWebsocketApiUrl;
        return builder.url(url).build();
    }

    /**
     * Marks the client as closed and starts a graceful close of the current socket, if any.
     *
     * @return the result of {@link WebSocket#close(int, String)}, or {@code true} when no socket
     *     has been created yet
     */
    @Override
    public boolean close(int code, String reason) {
        this.isClosed.set(true);
        WebSocket socket = this.webSocketClient;
        if (socket != null) {
            return socket.close(code, reason);
        }
        return true;
    }

    /** Immediately cancels the current socket, if any. */
    @Override
    public void cancel() {
        WebSocket socket = this.webSocketClient;
        if (socket != null) {
            socket.cancel();
        }
    }

    /**
     * Opens a new WebSocket connection and blocks until it is open.
     *
     * <p>The blocking wait ends when {@link #onOpen} completes {@link #connectionEmitter} or when
     * {@link #onFailure} fails it. Failed attempts are retried up to {@link #MAX_CONNECTION_TIMES}
     * times with a pause between attempts.
     *
     * @throws ApiException with code {@code ConnectionError} when every attempt fails, when the
     *     server answers {@code 401 Unauthorized}, or when the retry pause is interrupted. In the
     *     interrupt case the thread's interrupt flag is restored.
     */
    private void establishWebSocketClient(String apiKey, boolean isSecurityCheck, String workspace, Map<String, String> customHeaders, String baseWebSocketUrl) {
        String errorMessage = "";
        for (int attempt = 0; attempt < MAX_CONNECTION_TIMES; ++attempt) {
            try {
                Flowable<DashScopeResult> connection = Flowable.create(emitter -> {
                    this.connectionEmitter = emitter;
                    try {
                        this.client = OkHttpClientFactory.getOkHttpClient();
                        this.webSocketClient = this.client.newWebSocket(this.buildConnectionRequest(apiKey, isSecurityCheck, workspace, customHeaders, baseWebSocketUrl), this);
                    }
                    catch (Throwable ex) {
                        emitter.onError(ex);
                    }
                }, BackpressureStrategy.BUFFER);
                connection.blockingSubscribe();
                return;
            }
            catch (Throwable ex) {
                // getMessage() may be null; fall back to toString() so the checks below cannot NPE.
                errorMessage = ex.getMessage() != null ? ex.getMessage() : ex.toString();
                log.error(errorMessage);
                if (errorMessage.contains("401 Unauthorized")) {
                    break;
                }
                if (errorMessage.contains("Can not find api-key.")) {
                    // Rethrown as-is, without retrying.
                    throw ex;
                }
                if (attempt + 1 >= MAX_CONNECTION_TIMES) {
                    // No attempts left: do not pause before reporting the failure.
                    break;
                }
                try {
                    Thread.sleep(CONNECTION_RETRY_DELAY_MS);
                }
                catch (InterruptedException interrupted) {
                    // Preserve the interrupt for the caller and stop retrying.
                    Thread.currentThread().interrupt();
                    break;
                }
            }
        }
        throw new ApiException(Status.builder().code("ConnectionError").message(errorMessage).statusCode(44).build());
    }

    /**
     * Records that the socket is closed.
     *
     * <p>The close request flag is also reset. Otherwise, a later connection on this instance would
     * be torn down by the {@code isClosed} checks in the other callbacks.
     */
    @Override
    public void onClosed(WebSocket webSocket, int code, String reason) {
        log.debug("WebSocket {} closed: {}, {}", webSocket, code, reason);
        this.isOpen.set(false);
        this.isClosed.set(false);
    }

    /** Echoes the server's close code and completes the response stream if it is still active. */
    @Override
    public void onClosing(WebSocket webSocket, int code, String reason) {
        webSocket.close(code, null);
        log.debug("Websocket is closing, code: {}, reason: {}", code, reason);
        FlowableEmitter<DashScopeResult> emitter = this.responseEmitter;
        if (emitter != null && !emitter.isCancelled()) {
            emitter.onComplete();
        }
    }

    /**
     * Routes a transport failure to whichever stream is waiting on it.
     *
     * <p>If a connection attempt is still in progress, the failure goes to {@link
     * #connectionEmitter}. Otherwise it goes to the active {@link #responseEmitter}. If neither is
     * active, it is only logged.
     */
    @Override
    public void onFailure(WebSocket webSocket, Throwable t, Response response) {
        if (this.isClosed.get()) {
            log.debug("called close before but not working, close again in onFailure.");
            this.close(1013, "call closed before");
            return;
        }
        String responseBody = "";
        // The response, and its body, may be absent, e.g. for network errors.
        if (response != null && response.body() != null) {
            try {
                responseBody = response.body().string();
            }
            catch (IOException ex) {
                log.error(ex.getMessage());
            }
        }
        String failureMessage = String.format("Websocket failure %s, cause: %s, body: %s", t.getMessage(), t.getCause(), responseBody);
        log.error(failureMessage);
        this.isOpen.set(false);
        FlowableEmitter<DashScopeResult> connection = this.connectionEmitter;
        FlowableEmitter<DashScopeResult> responses = this.responseEmitter;
        if (connection != null && !connection.isCancelled()) {
            connection.onError(new Exception(failureMessage, t));
        } else if (responses != null && !responses.isCancelled()) {
            responses.onError(new Exception(failureMessage, t));
        } else {
            log.error(failureMessage);
        }
    }

    /** Dispatches a JSON text frame by its header event type. */
    @Override
    public void onMessage(WebSocket webSocket, String text) {
        if (this.isClosed.get()) {
            log.debug("called close before but not working, close again in onMessage.");
            this.close(1013, "call closed before");
            return;
        }
        log.debug("{}", text);
        if (this.isFirstMessage.compareAndSet(false, true)) {
            log.debug("Receive first package.");
        }
        FlowableEmitter<DashScopeResult> emitter = this.responseEmitter;
        try {
            WebSocketResponse response = JsonUtils.fromJson(text, WebSocketResponse.class);
            switch (response.header.event) {
                case TASK_STARTED: {
                    if (response.payload.output != null || response.payload.usage != null) {
                        emitter.onNext(this.parseTextResult(text));
                        break;
                    }
                    if (this.passTaskStarted.get()) {
                        DashScopeResult startMessage = this.parseTextResult(text);
                        startMessage.setEvent(WebSocketEventType.TASK_STARTED.getValue());
                        emitter.onNext(startMessage);
                    }
                    break;
                }
                case TASK_FAILED: {
                    log.error("Receive task_failed message: {}", text);
                    Status st = Status.builder().code(response.header.code).message(response.header.message).requestId(response.header.taskId).statusCode(44).isJson(true).build();
                    if (!emitter.isCancelled()) {
                        emitter.onError(new ApiException(st));
                    } else {
                        log.error("Something wrong, receive task failed message: {}", text);
                    }
                    // A failed task is terminal; it must not fall through to the TASK_FINISHED handling.
                    break;
                }
                case TASK_FINISHED: {
                    if (response.payload.output != null || response.payload.usage != null) {
                        emitter.onNext(this.parseTextResult(text));
                    }
                    emitter.onComplete();
                    break;
                }
                case RESULT_GENERATED: {
                    emitter.onNext(this.parseTextResult(text));
                    break;
                }
                default: {
                    emitter.onError(new ApiException(Status.builder().code("UnknownMessage").message(String.format("Receive unknown message: %s", text)).statusCode(44).build()));
                    break;
                }
            }
        }
        catch (Throwable ex) {
            emitter.onError(new ApiException(Status.builder().code("MessageFormatError").message(String.format("Receive message: %s, json deserialize exception", text)).statusCode(44).build()));
        }
    }

    /** Converts a text frame into a result, using the flatten flag of the current request. */
    private DashScopeResult parseTextResult(String text) {
        return (DashScopeResult) new DashScopeResult().fromResponse(Protocol.WEBSOCKET, NetworkResponse.builder().message(text).build(), this.isFlattenResult);
    }

    /** Forwards a binary frame to the response stream as a result. */
    @Override
    public void onMessage(WebSocket webSocket, ByteString bytes) {
        if (this.isClosed.get()) {
            log.debug("called close before but not working, close again in onMessage.");
            this.close(1013, "call closed before");
            return;
        }
        if (this.isFirstMessage.compareAndSet(false, true)) {
            log.debug("Receive first binary package.");
        }
        DashScopeResult result = (DashScopeResult) new DashScopeResult().fromResponse(Protocol.WEBSOCKET, NetworkResponse.builder().binary(bytes.asByteBuffer()).build(), this.isFlattenResult);
        this.responseEmitter.onNext(result);
    }

    /** Marks the socket open and releases a pending {@link #establishWebSocketClient} call. */
    @Override
    public void onOpen(WebSocket webSocket, Response response) {
        if (this.isClosed.get()) {
            log.debug("called close before but not working, close again in onOpen.");
            this.close(1013, "call closed before");
            return;
        }
        this.isOpen.set(true);
        FlowableEmitter<DashScopeResult> connection = this.connectionEmitter;
        if (connection != null && !connection.isCancelled()) {
            connection.onComplete();
        }
    }

    /**
     * Sends a text frame, first opening the connection if it is not open.
     *
     * <p>When {@code passTaskStarted} is set, the message is sent once and a failure is only
     * logged. Otherwise, a failed send triggers a reconnect and a pause, up to {@link
     * #MAX_SEND_RETRIES} attempts. Exhausting the attempts is logged, not thrown.
     */
    protected void sendTextWithRetry(String apiKey, boolean isSecurityCheck, String message, String workspace, Map<String, String> customHeaders, String baseWebSocketUrl) {
        if (!this.isOpen.get()) {
            this.establishWebSocketClient(apiKey, isSecurityCheck, workspace, customHeaders, baseWebSocketUrl);
        }
        if (this.passTaskStarted.get()) {
            // DEBUG like the retrying path: messages can carry user payloads.
            log.debug("Sending message: {}", message);
            if (!this.webSocketClient.send(message)) {
                log.warn("Send request failed, return without retry.");
            }
            return;
        }
        for (int retryCount = 0; retryCount < MAX_SEND_RETRIES; ++retryCount) {
            log.debug("Sending message: {}", message);
            if (this.webSocketClient.send(message)) {
                return;
            }
            this.establishWebSocketClient(apiKey, isSecurityCheck, workspace, customHeaders, baseWebSocketUrl);
            log.warn("Send request failed, the connection may closed, will reconnect and send again");
            Observable.timer(SEND_RETRY_DELAY_MS, TimeUnit.MILLISECONDS).blockingSingle();
        }
        log.warn("Send request failed after {} attempts, giving up.", MAX_SEND_RETRIES);
    }

    /**
     * Sends a binary frame, first opening the connection if it is not open.
     *
     * <p>A failed send triggers a reconnect and a pause, up to {@link #MAX_SEND_RETRIES} attempts.
     * Exhausting the attempts is logged, not thrown.
     */
    protected void sendBinaryWithRetry(String apiKey, boolean isSecurityCheck, ByteString message, String workspace, Map<String, String> customHeaders, String baseWebSocketUrl) {
        if (!this.isOpen.get()) {
            this.establishWebSocketClient(apiKey, isSecurityCheck, workspace, customHeaders, baseWebSocketUrl);
        }
        for (int retryCount = 0; retryCount < MAX_SEND_RETRIES; ++retryCount) {
            if (this.webSocketClient.send(message)) {
                return;
            }
            this.establishWebSocketClient(apiKey, isSecurityCheck, workspace, customHeaders, baseWebSocketUrl);
            log.warn("Send request failed, the connection may closed, will reconnect and send again");
            Observable.timer(SEND_RETRY_DELAY_MS, TimeUnit.MILLISECONDS).blockingSingle();
        }
        log.warn("Send binary request failed after {} attempts, giving up.", MAX_SEND_RETRIES);
    }

    /** Sends the start-task message and then, if present, the request's binary payload. */
    private void sendBatchRequest(HalfDuplexRequest req) {
        Object binaryData = req.getWebsocketBinaryData();
        this.sendTextWithRetry(req.getApiKey(), req.isSecurityCheck(), JsonUtils.toJson(req.getStartTaskMessage()), req.getWorkspace(), req.getHeaders(), req.getBaseWebSocketUrl());
        if (binaryData != null) {
            this.sendBinaryWithRetry(req.getApiKey(), req.isSecurityCheck(), ByteString.of((ByteBuffer) binaryData), req.getWorkspace(), req.getHeaders(), req.getBaseWebSocketUrl());
        }
    }

    /**
     * Sends a non-streaming-output request and blocks for its single result.
     *
     * @throws ApiException with code {@code Invalid call} for streaming-output modes
     */
    @Override
    public DashScopeResult send(HalfDuplexRequest req) {
        if (req.getStreamingMode() != StreamingMode.NONE && req.getStreamingMode() != StreamingMode.IN) {
            throw new ApiException(Status.builder().code("Invalid call").statusCode(44).message("Please use streamOut interface of websocket.").build());
        }
        Flowable<DashScopeResult> flowable = Flowable.create(emitter -> {
            this.responseEmitter = emitter;
            this.isFlattenResult = req.getIsFlatten();
        }, BackpressureStrategy.BUFFER);
        // Runs the create callback once so an emitter and the flatten flag are installed before
        // any traffic. The emitter is cancelled right away; blockingSingle() installs the live one.
        // NOTE(review): events arriving between these two subscriptions are dropped.
        flowable.subscribe().dispose();
        this.sendBatchRequest(req);
        return flowable.blockingSingle();
    }

    /**
     * Sends a non-streaming-output request and reports its events to {@code callback}.
     *
     * @throws ApiException with code {@code Invalid call} for streaming-output modes
     */
    @Override
    public void send(HalfDuplexRequest req, final ResultCallback<DashScopeResult> callback) {
        if (req.getStreamingMode() != StreamingMode.NONE && req.getStreamingMode() != StreamingMode.IN) {
            throw new ApiException(Status.builder().code("Invalid call").statusCode(44).message("Please use streamOut interface of websocket.").build());
        }
        Flowable<DashScopeResult> flowable = Flowable.create(emitter -> {
            this.responseEmitter = emitter;
            this.isFlattenResult = req.getIsFlatten();
        }, BackpressureStrategy.BUFFER);
        // Subscribe the callback before sending, so responses that arrive as soon as the request
        // is written reach the callback instead of a cancelled emitter.
        flowable.subscribe(msg -> callback.onEvent(msg), err -> callback.onError(new ApiException(err)), () -> callback.onComplete());
        this.sendBatchRequest(req);
    }

    /**
     * Sends the request and returns the stream of its results.
     *
     * <p>NOTE(review): the live emitter is installed only when the caller subscribes to the
     * returned Flowable. Events received before that go to a cancelled emitter and are dropped.
     */
    @Override
    public Flowable<DashScopeResult> streamOut(HalfDuplexRequest req) {
        Flowable<DashScopeResult> flowable = Flowable.create(emitter -> {
            this.responseEmitter = emitter;
            this.isFlattenResult = req.getIsFlatten();
        }, BackpressureStrategy.BUFFER);
        flowable.subscribe().dispose();
        this.sendBatchRequest(req);
        return flowable;
    }

    /** Callback form of {@link #streamOut(HalfDuplexRequest)}. */
    @Override
    public void streamOut(HalfDuplexRequest req, final ResultCallback<DashScopeResult> callback) {
        Flowable<DashScopeResult> flowable = this.streamOut(req);
        flowable.subscribe(msg -> callback.onEvent(msg), err -> callback.onError(new ApiException(err)), () -> callback.onComplete());
    }

    /**
     * Asynchronously sends the start message, every item of the request's streaming data, and
     * finally the finished-task message.
     *
     * <p>Failures are reported through {@link #responseEmitter} rather than the returned future.
     */
    protected CompletableFuture<Void> sendStreamRequest(final FullDuplexRequest req) {
        return CompletableFuture.runAsync(() -> {
            try {
                this.isFirstMessage.set(false);
                JsonObject startMessage = req.getStartTaskMessage();
                final String taskId = startMessage.get("header").getAsJsonObject().get("task_id").getAsString();
                this.sendTextWithRetry(req.getApiKey(), req.isSecurityCheck(), JsonUtils.toJson(startMessage), req.getWorkspace(), req.getHeaders(), req.getBaseWebSocketUrl());
                Flowable<Object> streamingData = req.getStreamingData();
                streamingData.subscribe(data -> {
                    try {
                        if (data instanceof String) {
                            JsonObject continueData = req.getContinueMessage((String) data, taskId);
                            this.sendTextWithRetry(req.getApiKey(), req.isSecurityCheck(), JsonUtils.toJson(continueData), req.getWorkspace(), req.getHeaders(), req.getBaseWebSocketUrl());
                        } else if (data instanceof byte[]) {
                            this.sendBinaryWithRetry(req.getApiKey(), req.isSecurityCheck(), ByteString.of((byte[]) data), req.getWorkspace(), req.getHeaders(), req.getBaseWebSocketUrl());
                        } else if (data instanceof ByteBuffer) {
                            this.sendBinaryWithRetry(req.getApiKey(), req.isSecurityCheck(), ByteString.of((ByteBuffer) data), req.getWorkspace(), req.getHeaders(), req.getBaseWebSocketUrl());
                        } else {
                            JsonObject continueData = req.getContinueMessage(data, taskId);
                            this.sendTextWithRetry(req.getApiKey(), req.isSecurityCheck(), JsonUtils.toJson(continueData), req.getWorkspace(), req.getHeaders(), req.getBaseWebSocketUrl());
                        }
                    }
                    catch (Throwable ex) {
                        log.error(String.format("sendStreamData exception: %s", ex.getMessage()));
                        this.responseEmitter.onError(ex);
                    }
                }, err -> {
                    log.error("Get stream data error!", err);
                    this.responseEmitter.onError(err);
                }, () -> {
                    log.debug("Stream data send completed!");
                    this.sendTextWithRetry(req.getApiKey(), req.isSecurityCheck(), JsonUtils.toJson(req.getFinishedTaskMessage(taskId)), req.getWorkspace(), req.getHeaders(), req.getBaseWebSocketUrl());
                });
            }
            catch (Throwable ex) {
                log.error(String.format("sendStreamData exception: %s", ex.getMessage()));
                this.responseEmitter.onError(ex);
            }
        });
    }

    /**
     * Settles the send future once the response stream has terminated.
     *
     * <p>A future that is still running is cancelled first. Cancellation or failure is logged, not
     * thrown. NOTE(review): {@link CompletableFuture#cancel(boolean)} does not interrupt the task,
     * so a send that is already running may continue in the background.
     */
    private void joinSendFuture(CompletableFuture<Void> future) {
        try {
            if (!future.isDone()) {
                future.cancel(true);
            }
            future.join();
        }
        catch (CancellationException | CompletionException ex) {
            log.error("Sending streaming data exception: {}", ex.getMessage());
        }
    }

    /** Streams the request's input and blocks for the first result. */
    @Override
    public DashScopeResult streamIn(FullDuplexRequest req) {
        Flowable<DashScopeResult> flowable = Flowable.create(emitter -> {
            this.responseEmitter = emitter;
            this.isFlattenResult = req.getIsFlatten();
        }, BackpressureStrategy.BUFFER);
        flowable.subscribe().dispose();
        final CompletableFuture<Void> future = this.sendStreamRequest(req);
        return flowable.doOnError(err -> this.joinSendFuture(future)).doOnComplete(() -> this.joinSendFuture(future)).blockingFirst();
    }

    /** Callback form of {@link #streamIn(FullDuplexRequest)}: one event, then completion. */
    @Override
    public void streamIn(FullDuplexRequest req, ResultCallback<DashScopeResult> callback) throws NoApiKeyException, ApiException {
        DashScopeResult res = this.streamIn(req);
        callback.onEvent(res);
        callback.onComplete();
    }

    /**
     * Streams the request's input asynchronously and returns the stream of results.
     *
     * <p>NOTE(review): as with {@link #streamOut(HalfDuplexRequest)}, events received before the
     * caller subscribes to the returned Flowable are dropped.
     */
    @Override
    public Flowable<DashScopeResult> duplex(FullDuplexRequest req) throws NoApiKeyException, ApiException {
        Flowable<DashScopeResult> flowable = Flowable.create(emitter -> {
            this.responseEmitter = emitter;
            this.isFlattenResult = req.getIsFlatten();
        }, BackpressureStrategy.BUFFER);
        flowable.subscribe().dispose();
        final CompletableFuture<Void> future = this.sendStreamRequest(req);
        return flowable.doOnError(err -> this.joinSendFuture(future)).doOnComplete(() -> this.joinSendFuture(future));
    }

    /** Callback form of {@link #duplex(FullDuplexRequest)}. */
    @Override
    public void duplex(FullDuplexRequest req, final ResultCallback<DashScopeResult> callback) throws NoApiKeyException, ApiException {
        Flowable<DashScopeResult> flowable = this.duplex(req);
        flowable.subscribe(msg -> callback.onEvent(msg), err -> callback.onError(new ApiException(err)), () -> callback.onComplete());
    }
}

