/*
 * Decompiled with CFR 0.152.
 */
package org.apache.camel.component.tensorflow.serving;

import com.google.protobuf.GeneratedMessageV3;
import com.google.protobuf.Int64Value;
import java.util.Optional;
import org.apache.camel.Endpoint;
import org.apache.camel.Exchange;
import org.apache.camel.Message;
import org.apache.camel.component.tensorflow.serving.TensorFlowServingConfiguration;
import org.apache.camel.component.tensorflow.serving.TensorFlowServingEndpoint;
import org.apache.camel.support.DefaultProducer;
import tensorflow.serving.Classification;
import tensorflow.serving.GetModelMetadata;
import tensorflow.serving.GetModelStatus;
import tensorflow.serving.Model;
import tensorflow.serving.ModelServiceGrpc;
import tensorflow.serving.Predict;
import tensorflow.serving.PredictionServiceGrpc;
import tensorflow.serving.RegressionOuterClass;

public class TensorFlowServingProducer
extends DefaultProducer {
    private final String api;
    private final ModelServiceGrpc.ModelServiceBlockingStub modelService;
    private final PredictionServiceGrpc.PredictionServiceBlockingStub predictionService;

    public TensorFlowServingProducer(TensorFlowServingEndpoint endpoint) {
        super((Endpoint)endpoint);
        this.api = endpoint.getApi();
        this.modelService = endpoint.getModelService();
        this.predictionService = endpoint.getPredictionService();
    }

    public TensorFlowServingEndpoint getEndpoint() {
        return (TensorFlowServingEndpoint)super.getEndpoint();
    }

    public void process(Exchange exchange) throws Exception {
        GeneratedMessageV3 response = switch (this.api) {
            case "model-status" -> this.modelStatus(exchange);
            case "model-metadata" -> this.modelMetadata(exchange);
            case "classify" -> this.classify(exchange);
            case "regress" -> this.regress(exchange);
            case "predict" -> this.predict(exchange);
            default -> throw new IllegalArgumentException("Unsupported API: " + this.api);
        };
        exchange.getMessage().setBody((Object)response);
    }

    private Model.ModelSpec.Builder modelSpec(Exchange exchange) {
        Message message = exchange.getMessage();
        TensorFlowServingConfiguration configuration = this.getEndpoint().getConfiguration();
        String modelName = Optional.ofNullable((String)message.getHeader("CamelTensorFlowServingModelName", String.class)).orElse(configuration.getModelName());
        Long modelVersion = Optional.ofNullable((Long)message.getHeader("CamelTensorFlowServingModelVersion", Long.class)).orElse(configuration.getModelVersion());
        String modelVersionLabel = Optional.ofNullable((String)message.getHeader("CamelTensorFlowServingModelVersionLabel", String.class)).orElse(configuration.getModelVersionLabel());
        String signatureName = Optional.ofNullable((String)message.getHeader("CamelTensorFlowServingSignatureName", String.class)).orElse(configuration.getSignatureName());
        Model.ModelSpec.Builder builder = Model.ModelSpec.newBuilder().setName(modelName);
        if (modelVersion != null) {
            builder.setVersion(Int64Value.of(modelVersion));
        }
        if (modelVersionLabel != null) {
            builder.setVersionLabel(modelVersionLabel);
        }
        if (signatureName != null) {
            builder.setSignatureName(signatureName);
        }
        return builder;
    }

    private GetModelStatus.GetModelStatusResponse modelStatus(Exchange exchange) {
        Message message = exchange.getMessage();
        GetModelStatus.GetModelStatusRequest request = (GetModelStatus.GetModelStatusRequest)message.getBody(GetModelStatus.GetModelStatusRequest.class);
        GetModelStatus.GetModelStatusRequest.Builder builder = GetModelStatus.GetModelStatusRequest.newBuilder().setModelSpec(this.modelSpec(exchange));
        if (request != null) {
            builder.mergeFrom(request);
        }
        return this.modelService.getModelStatus(builder.build());
    }

    private GetModelMetadata.GetModelMetadataResponse modelMetadata(Exchange exchange) {
        Message message = exchange.getMessage();
        GetModelMetadata.GetModelMetadataRequest request = (GetModelMetadata.GetModelMetadataRequest)message.getBody(GetModelMetadata.GetModelMetadataRequest.class);
        GetModelMetadata.GetModelMetadataRequest.Builder builder = GetModelMetadata.GetModelMetadataRequest.newBuilder().setModelSpec(this.modelSpec(exchange)).addMetadataField("signature_def");
        if (request != null) {
            builder.mergeFrom(request);
        }
        return this.predictionService.getModelMetadata(builder.build());
    }

    private Classification.ClassificationResponse classify(Exchange exchange) {
        Message message = exchange.getMessage();
        Classification.ClassificationRequest request = (Classification.ClassificationRequest)message.getBody(Classification.ClassificationRequest.class);
        Classification.ClassificationRequest.Builder builder = Classification.ClassificationRequest.newBuilder().setModelSpec(this.modelSpec(exchange));
        if (request != null) {
            builder.mergeFrom(request);
        }
        return this.predictionService.classify(builder.build());
    }

    private RegressionOuterClass.RegressionResponse regress(Exchange exchange) {
        Message message = exchange.getMessage();
        RegressionOuterClass.RegressionRequest request = (RegressionOuterClass.RegressionRequest)message.getBody(RegressionOuterClass.RegressionRequest.class);
        RegressionOuterClass.RegressionRequest.Builder builder = RegressionOuterClass.RegressionRequest.newBuilder().setModelSpec(this.modelSpec(exchange));
        if (request != null) {
            builder.mergeFrom(request);
        }
        return this.predictionService.regress(builder.build());
    }

    private Predict.PredictResponse predict(Exchange exchange) {
        Message message = exchange.getMessage();
        Predict.PredictRequest request = (Predict.PredictRequest)message.getBody(Predict.PredictRequest.class);
        Predict.PredictRequest.Builder builder = Predict.PredictRequest.newBuilder().setModelSpec(this.modelSpec(exchange));
        if (request != null) {
            builder.mergeFrom(request);
        }
        return this.predictionService.predict(builder.build());
    }
}

