/*
 * Decompiled with CFR 0.152.
 */
package apoc.ml;

import apoc.Extended;
import apoc.result.MapResult;
import apoc.util.JsonUtil;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.lang.invoke.CallSite;
import java.net.MalformedURLException;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Stream;
import org.neo4j.procedure.Description;
import org.neo4j.procedure.Name;
import org.neo4j.procedure.Procedure;

/**
 * Cypher procedures that call Google Vertex AI "predict" endpoints for text
 * embeddings, text completion and chat completion.
 *
 * <p>Every procedure builds a JSON body of the form
 * {@code {"instances": ..., "parameters": ...}}, posts it through
 * {@link JsonUtil#loadJson} with a bearer token, and reads the
 * {@code $.predictions} path of the response.
 *
 * <p>The endpoint URL template can be overridden with the system property
 * named by {@link #APOC_ML_VERTEXAI_URL}; it is formatted with
 * (region, project, region, model), in that order.
 */
@Extended
public class VertexAI {
    private static final String BASE_URL = "https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:predict";
    /** Name of the system property that overrides the endpoint URL template. */
    public static final String APOC_ML_VERTEXAI_URL = "apoc.ml.vertexai.url";
    /** Region used when the configuration map has no {@code region} entry. */
    public static final String DEFAULT_REGION = "us-central1";

    /** Shared serializer for request payloads; avoids building a new mapper per call. */
    private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();

    /**
     * Builds the endpoint URL, headers and JSON payload, and performs the request.
     *
     * @param accessToken      bearer token sent in the {@code Authorization} header
     * @param project          project id inserted into the endpoint URL
     * @param configuration    optional settings; {@code model} and {@code region} select the
     *                         endpoint, the remaining known keys feed the request parameters.
     *                         A {@code null} map is treated as empty.
     * @param defaultModel     model used when {@code configuration} has no {@code model} entry
     * @param inputs           value serialized as the {@code instances} field of the payload
     * @param jsonPath         JSON path applied to the response
     * @param retainConfigKeys parameter names that are allowed into the {@code parameters} field
     * @return the stream produced by {@link JsonUtil#loadJson} for the given path
     * @throws IllegalArgumentException if {@code accessToken} or {@code project} is null or blank
     * @throws JsonProcessingException  if the payload cannot be serialized
     * @throws MalformedURLException    declared for the underlying load call
     */
    private static Stream<Object> executeRequest(String accessToken, String project, Map<String, Object> configuration, String defaultModel, Object inputs, String jsonPath, Collection<String> retainConfigKeys) throws JsonProcessingException, MalformedURLException {
        if (accessToken == null || accessToken.isBlank()) {
            throw new IllegalArgumentException("Access Token must not be empty");
        }
        if (project == null || project.isBlank()) {
            throw new IllegalArgumentException("Project must not be empty");
        }
        // A caller may pass an explicit null instead of relying on the "{}" default.
        Map<String, Object> config = configuration == null ? Map.of() : configuration;

        String urlTemplate = System.getProperty(APOC_ML_VERTEXAI_URL, BASE_URL);
        String model = stringSetting(config, "model", defaultModel);
        String region = stringSetting(config, "region", DEFAULT_REGION);
        // The region appears twice in the template: once in the host, once in the path.
        String endpoint = String.format(urlTemplate, region, project, region, model);

        Map<String, Object> headers = Map.of(
                "Content-Type", "application/json",
                "Accept", "application/json",
                "Authorization", "Bearer " + accessToken);
        Map<String, Object> data = Map.of(
                "instances", inputs,
                "parameters", VertexAI.getParameters(config, retainConfigKeys));
        String payload = OBJECT_MAPPER.writeValueAsString(data);
        return JsonUtil.loadJson(endpoint, headers, payload, jsonPath, true, List.of());
    }

    /**
     * Reads a configuration entry as a string, using {@code fallback} when the
     * key is absent or mapped to {@code null}.
     */
    private static String stringSetting(Map<String, Object> config, String key, String fallback) {
        Object value = config.get(key);
        return value == null ? fallback : value.toString();
    }

    /**
     * Flattens the response stream, where each element is expected to be the list
     * found under {@code $.predictions}, into a stream of single prediction maps.
     */
    @SuppressWarnings("unchecked") // the response is untyped JSON; shape is asserted by the casts
    private static Stream<Map<String, Object>> flattenPredictions(Stream<Object> resultStream) {
        return resultStream.flatMap(v -> ((List<Map<String, Object>>) v).stream());
    }

    /**
     * Converts one prediction map into an {@link EmbeddingResult} by reading
     * {@code embeddings.values} from it.
     */
    @SuppressWarnings("unchecked") // the response is untyped JSON; shape is asserted by the casts
    private static EmbeddingResult toEmbeddingResult(int index, String text, Map<String, Object> prediction) {
        Map<String, Object> embeddings = (Map<String, Object>) prediction.get("embeddings");
        return new EmbeddingResult(index, text, (List<Double>) embeddings.get("values"));
    }

    /**
     * Requests one embedding per input text.
     *
     * @param texts         texts to embed; each is sent as {@code {"content": text}}
     * @param accessToken   bearer token for the request
     * @param project       project id inserted into the endpoint URL
     * @param configuration optional {@code model} / {@code region} overrides
     * @return one result per prediction, pairing it with the text at the same position
     * @throws Exception if validation, serialization or the request fails
     */
    @Procedure(value="apoc.ml.vertexai.embedding")
    @Description(value="apoc.ml.vertexai.embedding([texts], accessToken, project, configuration) - returns the embeddings for a given text")
    public Stream<EmbeddingResult> getEmbedding(@Name(value="texts") List<String> texts, @Name(value="accessToken") String accessToken, @Name(value="project") String project, @Name(value="configuration", defaultValue="{}") Map<String, Object> configuration) throws Exception {
        List<Map<String, String>> inputs = texts.stream().map(text -> Map.of("content", text)).toList();
        // No parameter keys are retained for embeddings, so "parameters" is sent empty.
        Stream<Object> resultStream = VertexAI.executeRequest(accessToken, project, configuration, "textembedding-gecko", inputs, "$.predictions", List.of());
        // Predictions are matched to input texts by their position in the flattened stream.
        AtomicInteger position = new AtomicInteger();
        return flattenPredictions(resultStream).map(prediction -> {
            int index = position.getAndIncrement();
            return toEmbeddingResult(index, texts.get(index), prediction);
        });
    }

    /**
     * Sends a single prompt to the text completion model.
     *
     * @param prompt        prompt text, sent as {@code {"prompt": prompt}}
     * @param accessToken   bearer token for the request
     * @param project       project id inserted into the endpoint URL
     * @param configuration optional {@code model} / {@code region} overrides and generation
     *                      parameters ({@code temperature}, {@code topK}, {@code topP},
     *                      {@code maxOutputTokens})
     * @return one {@link MapResult} per prediction in the response
     * @throws Exception if validation, serialization or the request fails
     */
    @Procedure(value="apoc.ml.vertexai.completion")
    @Description(value="apoc.ml.vertexai.completion(prompt, accessToken, project, configuration) - prompts the completion API")
    public Stream<MapResult> completion(@Name(value="prompt") String prompt, @Name(value="accessToken") String accessToken, @Name(value="project") String project, @Name(value="configuration", defaultValue="{}") Map<String, Object> configuration) throws Exception {
        List<Map<String, String>> input = List.of(Map.of("prompt", prompt));
        List<String> parameterKeys = List.of("temperature", "topK", "topP", "maxOutputTokens");
        Stream<Object> resultStream = VertexAI.executeRequest(accessToken, project, configuration, "text-bison", input, "$.predictions", parameterKeys);
        return flattenPredictions(resultStream).map(MapResult::new);
    }

    /**
     * Builds the {@code parameters} object of a request from the configuration.
     *
     * <p>Each known parameter takes its value from {@code config}, or the built-in
     * default when the key is absent or mapped to {@code null}; the result is then
     * restricted to {@code retainKeys}.
     *
     * @param config     user configuration; {@code null} is treated as empty
     * @param retainKeys names of the parameters to keep
     * @return a new mutable map containing only the retained parameters
     */
    private static Map<String, Object> getParameters(Map<String, Object> config, Collection<String> retainKeys) {
        Map<String, Object> source = config == null ? Map.of() : config;
        Map<String, Object> result = new HashMap<>();
        putSetting(result, source, "temperature", 0.3);
        putSetting(result, source, "maxOutputTokens", 256);
        putSetting(result, source, "maxDecodeSteps", 200);
        putSetting(result, source, "topP", 0.8);
        putSetting(result, source, "topK", 40);
        result.keySet().retainAll(retainKeys);
        return result;
    }

    /**
     * Copies {@code key} from {@code source} into {@code target}, substituting
     * {@code fallback} when the key is absent or mapped to {@code null}.
     */
    private static void putSetting(Map<String, Object> target, Map<String, Object> source, String key, Object fallback) {
        Object value = source.get(key);
        target.put(key, value == null ? fallback : value);
    }

    /**
     * Sends a chat conversation to the chat model.
     *
     * @param messages      conversation messages, sent as the {@code messages} field
     * @param accessToken   bearer token for the request
     * @param project       project id inserted into the endpoint URL
     * @param configuration optional {@code model} / {@code region} overrides and generation
     *                      parameters ({@code temperature}, {@code topK}, {@code topP},
     *                      {@code maxOutputTokens})
     * @param context       sent as the {@code context} field; {@code null} is sent as ""
     * @param examples      sent as the {@code examples} field; {@code null} is sent as an empty list
     * @return one {@link MapResult} per prediction in the response
     * @throws Exception if validation, serialization or the request fails
     */
    @Procedure(value="apoc.ml.vertexai.chat")
    @Description(value="apoc.ml.vertexai.chat(messages, accessToken, project, configuration) - prompts the completion API")
    public Stream<MapResult> chatCompletion(@Name(value="messages") List<Map<String, String>> messages, @Name(value="accessToken") String accessToken, @Name(value="project") String project, @Name(value="configuration", defaultValue="{}") Map<String, Object> configuration, @Name(value="context", defaultValue="") String context, @Name(value="examples", defaultValue="[]") List<Map<String, Map<String, String>>> examples) throws Exception {
        // Map.of rejects null values, so fall back to the declared defaults.
        String effectiveContext = context == null ? "" : context;
        List<Map<String, Map<String, String>>> effectiveExamples = examples == null ? List.of() : examples;
        Map<String, Object> instance = Map.of("context", effectiveContext, "examples", effectiveExamples, "messages", messages);
        List<Map<String, Object>> inputs = List.of(instance);
        List<String> parameterKeys = List.of("temperature", "topK", "topP", "maxOutputTokens");
        Stream<Object> resultStream = VertexAI.executeRequest(accessToken, project, configuration, "chat-bison", inputs, "$.predictions", parameterKeys);
        return flattenPredictions(resultStream).map(MapResult::new);
    }

    /** Result row of the embedding procedure: input position, input text and its vector. */
    public static class EmbeddingResult {
        /** Position of the text in the input list. */
        public final long index;
        /** The input text this embedding belongs to. */
        public final String text;
        /** Embedding values read from the prediction. */
        public final List<Double> embedding;

        public EmbeddingResult(long index, String text, List<Double> embedding) {
            this.index = index;
            this.text = text;
            this.embedding = embedding;
        }
    }
}

