/*
 * Decompiled with CFR 0.152.
 */
package dev.langchain4j.model.jlama;

import com.github.tjake.jlama.model.AbstractModel;
import com.github.tjake.jlama.model.functions.Generator;
import com.github.tjake.jlama.safetensors.DType;
import com.github.tjake.jlama.safetensors.prompt.PromptContext;
import dev.langchain4j.internal.RetryUtils;
import dev.langchain4j.model.jlama.JlamaModel;
import dev.langchain4j.model.jlama.JlamaModelRegistry;
import dev.langchain4j.model.jlama.spi.JlamaLanguageModelBuilderFactory;
import dev.langchain4j.model.language.LanguageModel;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import dev.langchain4j.spi.ServiceHelper;
import java.nio.file.Path;
import java.util.Iterator;
import java.util.Optional;
import java.util.UUID;

/**
 * {@link LanguageModel} implementation backed by a locally loaded Jlama {@link AbstractModel}.
 *
 * <p>The constructor obtains the model through a {@link JlamaModelRegistry} and loads it with the
 * requested loader options. {@link #generate(String)} then runs a single completion against the
 * loaded model. Instances are normally created through {@link #builder()}.
 */
public class JlamaLanguageModel implements LanguageModel {
    /** Temperature used when the caller does not supply one. */
    private static final float DEFAULT_TEMPERATURE = 0.7f;

    private final AbstractModel model;
    private final Float temperature;
    private final Integer maxTokens;
    /** Random identifier created once per instance and passed to every generation call. */
    private final UUID id = UUID.randomUUID();

    /**
     * Downloads (via the registry) and loads the named model.
     *
     * @param modelCachePath path handed to {@link JlamaModelRegistry#getOrCreate}
     * @param modelName name of the model to download
     * @param authToken optional token forwarded to the registry download; may be {@code null}
     * @param threadCount loader thread count; loader default is kept when {@code null}
     * @param quantizeModelAtRuntime when {@code true}, the loader is switched to quantized mode
     * @param workingDirectory loader working directory; loader default is kept when {@code null}
     * @param workingQuantizedType loader working quantization type; loader default is kept when
     *     {@code null}
     * @param temperature sampling temperature; defaults to {@code 0.7} when {@code null}
     * @param maxTokens token limit passed to generation; defaults to the loaded model's configured
     *     context length when {@code null}
     */
    public JlamaLanguageModel(Path modelCachePath, String modelName, String authToken, Integer threadCount, Boolean quantizeModelAtRuntime, Path workingDirectory, DType workingQuantizedType, Float temperature, Integer maxTokens) {
        JlamaModelRegistry registry = JlamaModelRegistry.getOrCreate(modelCachePath);
        // The download is wrapped in the retry helper with 2 as its limit argument.
        JlamaModel jlamaModel = RetryUtils.withRetryMappingExceptions(() -> registry.downloadModel(modelName, Optional.ofNullable(authToken)), 2);

        // Only override loader settings the caller explicitly provided.
        JlamaModel.Loader loader = jlamaModel.loader();
        if (Boolean.TRUE.equals(quantizeModelAtRuntime)) {
            loader = loader.quantized();
        }
        if (workingQuantizedType != null) {
            loader = loader.workingQuantizationType(workingQuantizedType);
        }
        if (threadCount != null) {
            loader = loader.threadCount(threadCount);
        }
        if (workingDirectory != null) {
            loader = loader.workingDirectory(workingDirectory);
        }

        this.model = loader.load();
        this.temperature = temperature == null ? DEFAULT_TEMPERATURE : temperature;
        this.maxTokens = maxTokens == null ? this.model.getConfig().contextLength : maxTokens;
    }

    /**
     * Maps a Jlama generator finish reason onto the langchain4j {@link FinishReason}.
     *
     * @param reason the reason reported by the Jlama generator
     * @return the corresponding langchain4j finish reason
     * @throws IllegalArgumentException if {@code reason} is a constant with no mapping
     */
    public static FinishReason toFinishReason(Generator.FinishReason reason) {
        // Unqualified enum labels compile on every Java release that has switch expressions;
        // qualified labels are only accepted from Java 21 onwards.
        return switch (reason) {
            case STOP_TOKEN -> FinishReason.STOP;
            case MAX_TOKENS -> FinishReason.LENGTH;
            case ERROR -> FinishReason.OTHER;
            case TOOL_CALL -> FinishReason.TOOL_EXECUTION;
            default -> throw new IllegalArgumentException("Unknown reason: " + String.valueOf(reason));
        };
    }

    /**
     * Creates a builder, preferring the first {@link JlamaLanguageModelBuilderFactory} found by
     * {@link ServiceHelper#loadFactories} and falling back to the plain builder otherwise.
     *
     * @return a new builder instance
     */
    public static JlamaLanguageModelBuilder builder() {
        Iterator<JlamaLanguageModelBuilderFactory> factories = ServiceHelper.loadFactories(JlamaLanguageModelBuilderFactory.class).iterator();
        if (factories.hasNext()) {
            return factories.next().get();
        }
        return new JlamaLanguageModelBuilder();
    }

    /**
     * Generates a completion for {@code prompt} using the configured temperature and token limit.
     *
     * @param prompt the raw prompt text
     * @return the generated text together with token usage and the mapped finish reason
     */
    @Override
    public Response<String> generate(String prompt) {
        // The per-token callback is a no-op: only the final response object is used.
        Generator.Response r = this.model.generate(this.id, PromptContext.of(prompt), this.temperature, this.maxTokens, (token, time) -> {});
        TokenUsage usage = new TokenUsage(r.promptTokens, r.generatedTokens);
        return Response.from(r.responseText, usage, toFinishReason(r.finishReason));
    }

    /**
     * Fluent builder for {@link JlamaLanguageModel}. Every setter stores its argument unchanged
     * and returns this builder; unset values stay {@code null} and are defaulted (where a default
     * exists) by the {@link JlamaLanguageModel} constructor.
     */
    public static class JlamaLanguageModelBuilder {
        private Path modelCachePath;
        private String modelName;
        private String authToken;
        private Integer threadCount;
        private Boolean quantizeModelAtRuntime;
        private Path workingDirectory;
        private DType workingQuantizedType;
        private Float temperature;
        private Integer maxTokens;

        /** Sets the path passed to the model registry. */
        public JlamaLanguageModelBuilder modelCachePath(Path modelCachePath) {
            this.modelCachePath = modelCachePath;
            return this;
        }

        /** Sets the name of the model to download and load. */
        public JlamaLanguageModelBuilder modelName(String modelName) {
            this.modelName = modelName;
            return this;
        }

        /** Sets the optional token forwarded to the registry download. */
        public JlamaLanguageModelBuilder authToken(String authToken) {
            this.authToken = authToken;
            return this;
        }

        /** Sets the loader thread count. */
        public JlamaLanguageModelBuilder threadCount(Integer threadCount) {
            this.threadCount = threadCount;
            return this;
        }

        /** Sets whether the loader should be switched to quantized mode. */
        public JlamaLanguageModelBuilder quantizeModelAtRuntime(Boolean quantizeModelAtRuntime) {
            this.quantizeModelAtRuntime = quantizeModelAtRuntime;
            return this;
        }

        /** Sets the loader working directory. */
        public JlamaLanguageModelBuilder workingDirectory(Path workingDirectory) {
            this.workingDirectory = workingDirectory;
            return this;
        }

        /** Sets the loader working quantization type. */
        public JlamaLanguageModelBuilder workingQuantizedType(DType workingQuantizedType) {
            this.workingQuantizedType = workingQuantizedType;
            return this;
        }

        /** Sets the sampling temperature. */
        public JlamaLanguageModelBuilder temperature(Float temperature) {
            this.temperature = temperature;
            return this;
        }

        /** Sets the token limit passed to generation. */
        public JlamaLanguageModelBuilder maxTokens(Integer maxTokens) {
            this.maxTokens = maxTokens;
            return this;
        }

        /**
         * Builds the model from the values collected so far.
         *
         * @return a new {@link JlamaLanguageModel}
         */
        public JlamaLanguageModel build() {
            return new JlamaLanguageModel(this.modelCachePath, this.modelName, this.authToken, this.threadCount, this.quantizeModelAtRuntime, this.workingDirectory, this.workingQuantizedType, this.temperature, this.maxTokens);
        }

        /**
         * Returns a description of the builder state. The auth token is a credential, so only its
         * presence is shown (masked), never its value.
         */
        @Override
        public String toString() {
            String maskedToken = this.authToken == null ? null : "********";
            return "JlamaLanguageModel.JlamaLanguageModelBuilder(modelCachePath=" + String.valueOf(this.modelCachePath)
                    + ", modelName=" + this.modelName
                    + ", authToken=" + maskedToken
                    + ", threadCount=" + this.threadCount
                    + ", quantizeModelAtRuntime=" + this.quantizeModelAtRuntime
                    + ", workingDirectory=" + String.valueOf(this.workingDirectory)
                    + ", workingQuantizedType=" + String.valueOf(this.workingQuantizedType)
                    + ", temperature=" + this.temperature
                    + ", maxTokens=" + this.maxTokens + ")";
        }
    }
}

