/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.repository.zoo;

import ai.djl.Application;
import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.nn.Block;
import ai.djl.repository.zoo.DefaultModelZoo;
import ai.djl.repository.zoo.ModelLoader;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.DefaultTranslatorFactory;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorFactory;
import ai.djl.util.JsonUtils;
import ai.djl.util.Progress;
import com.google.gson.Gson;
import java.io.IOException;
import java.net.MalformedURLException;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class Criteria<I, O> {
    // Settings copied from the builder. They are assigned exactly once, in the
    // constructor, so they are declared final.
    private final Application application;
    private final Class<I> inputClass;
    private final Class<O> outputClass;
    private final String engine;
    private final Device device;
    private final String groupId;
    private final String artifactId;
    private final ModelZoo modelZoo;
    private final Map<String, String> filters;
    private final Map<String, Object> arguments;
    private final Map<String, String> options;
    private final TranslatorFactory factory;
    private final Block block;
    private final String modelName;
    private final Progress progress;
    // Loaders matching this criteria. Left null until isDownloaded() or loadModel()
    // resolves them for the first time; the resolved list is then reused.
    private List<ModelLoader> resolvedLoaders;

    /**
     * Creates a criteria from the current state of the given builder.
     *
     * <p>Each builder field is copied by reference; the filter, argument and option maps are
     * not duplicated.
     *
     * @param builder the builder whose field values are taken over
     */
    Criteria(Builder<I, O> builder) {
        application = builder.application;
        inputClass = builder.inputClass;
        outputClass = builder.outputClass;
        engine = builder.engine;
        device = builder.device;
        groupId = builder.groupId;
        artifactId = builder.artifactId;
        modelZoo = builder.modelZoo;
        filters = builder.filters;
        arguments = builder.arguments;
        options = builder.options;
        factory = builder.factory;
        block = builder.block;
        modelName = builder.modelName;
        progress = builder.progress;
    }

    /**
     * Reports whether every model loader matching this criteria says the model is downloaded.
     *
     * <p>The matching loaders are resolved on the first call and cached for later calls.
     *
     * @return {@code true} if all resolved loaders report the model as downloaded,
     *     {@code false} as soon as one of them does not
     * @throws IOException if a loader fails while checking
     * @throws ModelNotFoundException if no loader matches this criteria
     */
    public boolean isDownloaded() throws IOException, ModelNotFoundException {
        List<ModelLoader> candidates = resolvedLoaders;
        if (candidates == null) {
            candidates = resolveModelLoaders();
            resolvedLoaders = candidates;
        }
        for (ModelLoader candidate : candidates) {
            if (!candidate.isDownloaded(this)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Downloads the model through every resolved loader unless all of them already report it
     * as downloaded.
     *
     * @throws ModelNotFoundException if no loader matches this criteria
     * @throws IOException if checking or downloading fails
     */
    public void downloadModel() throws ModelNotFoundException, IOException {
        if (isDownloaded()) {
            return;
        }
        // isDownloaded() has populated resolvedLoaders at this point.
        for (ModelLoader candidate : resolvedLoaders) {
            candidate.downloadModel(this, progress);
        }
    }

    /**
     * Loads the model using the first matching loader that succeeds.
     *
     * <p>Loaders are tried in their resolved order. A {@link ModelNotFoundException} from one
     * loader is logged and the next loader is tried; any other exception propagates
     * immediately.
     *
     * @return the model returned by the first loader that does not fail
     * @throws IOException if a loader fails with an I/O error
     * @throws ModelNotFoundException if no loader matches this criteria, or every matching
     *     loader reported that the model was not found (the last such failure is the cause)
     * @throws MalformedModelException if a loader reports a malformed model
     */
    public ZooModel<I, O> loadModel() throws IOException, ModelNotFoundException, MalformedModelException {
        List<ModelLoader> candidates = resolvedLoaders;
        if (candidates == null) {
            candidates = resolveModelLoaders();
            resolvedLoaders = candidates;
        }
        Logger log = LoggerFactory.getLogger(ModelZoo.class);
        ModelNotFoundException mostRecent = null;
        for (ModelLoader candidate : candidates) {
            try {
                return candidate.loadModel(this);
            } catch (ModelNotFoundException notFound) {
                mostRecent = notFound;
                // Full stack trace only at trace level; a one-line summary at debug level.
                log.trace("", notFound);
                log.debug(
                        "{} for ModelLoader: {}:{}",
                        notFound.getMessage(),
                        candidate.getGroupId(),
                        candidate.getArtifactId());
            }
        }
        throw new ModelNotFoundException("No model with the specified URI or the matching Input/Output type is found.", mostRecent);
    }

    /**
     * Returns the application configured for this criteria.
     *
     * @return the {@link Application} taken from the builder
     */
    public Application getApplication() {
        return application;
    }

    /**
     * Returns the input type configured for this criteria.
     *
     * @return the input class taken from the builder
     */
    public Class<I> getInputClass() {
        return inputClass;
    }

    /**
     * Returns the output type configured for this criteria.
     *
     * @return the output class taken from the builder
     */
    public Class<O> getOutputClass() {
        return outputClass;
    }

    /**
     * Returns the engine name configured for this criteria.
     *
     * @return the engine name taken from the builder
     */
    public String getEngine() {
        return engine;
    }

    /**
     * Returns the device configured for this criteria.
     *
     * @return the {@link Device} taken from the builder
     */
    public Device getDevice() {
        return device;
    }

    /**
     * Returns the group id configured for this criteria.
     *
     * @return the group id taken from the builder
     */
    public String getGroupId() {
        return groupId;
    }

    /**
     * Returns the artifact id configured for this criteria.
     *
     * @return the artifact id taken from the builder
     */
    public String getArtifactId() {
        return artifactId;
    }

    /**
     * Returns the model zoo configured for this criteria.
     *
     * @return the {@link ModelZoo} taken from the builder
     */
    public ModelZoo getModelZoo() {
        return modelZoo;
    }

    /**
     * Returns the search filters configured for this criteria.
     *
     * @return the filter map held by this criteria (the stored instance, not a copy)
     */
    public Map<String, String> getFilters() {
        return filters;
    }

    /**
     * Returns the arguments configured for this criteria.
     *
     * @return the argument map held by this criteria (the stored instance, not a copy)
     */
    public Map<String, Object> getArguments() {
        return arguments;
    }

    /**
     * Returns the options configured for this criteria.
     *
     * @return the option map held by this criteria (the stored instance, not a copy)
     */
    public Map<String, String> getOptions() {
        return options;
    }

    /**
     * Returns the translator factory configured for this criteria.
     *
     * @return the {@link TranslatorFactory} taken from the builder
     */
    public TranslatorFactory getTranslatorFactory() {
        return factory;
    }

    /**
     * Returns the block configured for this criteria.
     *
     * @return the {@link Block} taken from the builder
     */
    public Block getBlock() {
        return block;
    }

    /**
     * Returns the model name configured for this criteria.
     *
     * @return the model name taken from the builder
     */
    public String getModelName() {
        return modelName;
    }

    /**
     * Returns the progress tracker configured for this criteria.
     *
     * @return the {@link Progress} taken from the builder
     */
    public Progress getProgress() {
        return progress;
    }

    /**
     * Builds a multi-line, human-readable summary of this criteria.
     *
     * <p>The input and output types are always printed. Every other setting is printed only
     * when it is non-null, and a note is added when no translator factory is set. Filters and
     * options are rendered as JSON with the shared Gson instance; arguments use a Gson
     * instance that excludes fields without the expose annotation.
     *
     * @return the textual description of this criteria
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder(200);
        sb.append("Criteria:\n");
        if (application != null) {
            sb.append("\tApplication: ").append(application).append('\n');
        }
        sb.append("\tInput: ").append(inputClass).append('\n');
        sb.append("\tOutput: ").append(outputClass).append('\n');
        if (engine != null) {
            sb.append("\tEngine: ").append(engine).append('\n');
        }
        if (modelZoo != null) {
            sb.append("\tModelZoo: ").append(modelZoo.getGroupId()).append('\n');
        }
        if (groupId != null) {
            sb.append("\tGroupID: ").append(groupId).append('\n');
        }
        if (artifactId != null) {
            sb.append("\tArtifactId: ").append(artifactId).append('\n');
        }
        if (filters != null) {
            sb.append("\tFilter: ").append(JsonUtils.GSON.toJson(filters)).append('\n');
        }
        if (arguments != null) {
            Gson gson = JsonUtils.builder().excludeFieldsWithoutExposeAnnotation().create();
            sb.append("\tArguments: ").append(gson.toJson(arguments)).append('\n');
        }
        if (options != null) {
            sb.append("\tOptions: ").append(JsonUtils.GSON.toJson(options)).append('\n');
        }
        if (factory == null) {
            sb.append("\tNo translator supplied\n");
        }
        return sb.toString();
    }

    /**
     * Creates a new builder pre-populated with the settings of this criteria.
     *
     * <p>The filter, argument and option maps are copied before being handed to the builder,
     * so adding entries through the returned builder does not modify this criteria. The
     * copies keep the iteration order of the source maps.
     *
     * @return a builder carrying the same settings as this criteria
     */
    public Builder<I, O> toBuilder() {
        Map<String, String> filtersCopy = filters == null ? null : new LinkedHashMap<>(filters);
        Map<String, Object> argumentsCopy = arguments == null ? null : new LinkedHashMap<>(arguments);
        Map<String, String> optionsCopy = options == null ? null : new LinkedHashMap<>(options);
        return Criteria.builder()
                .setTypes(inputClass, outputClass)
                .optApplication(application)
                .optEngine(engine)
                .optDevice(device)
                .optGroupId(groupId)
                .optArtifactId(artifactId)
                .optModelZoo(modelZoo)
                .optFilters(filtersCopy)
                .optArguments(argumentsCopy)
                .optOptions(optionsCopy)
                .optTranslatorFactory(factory)
                .optBlock(block)
                .optModelName(modelName)
                .optProgress(progress);
    }

    /**
     * Creates a new, untyped builder.
     *
     * <p>Call {@link Builder#setTypes(Class, Class)} on the result to obtain a builder bound
     * to concrete input and output types.
     *
     * @return a fresh builder
     */
    public static Builder<?, ?> builder() {
        // Diamond instead of the raw type: avoids an unchecked/rawtypes warning.
        return new Builder<>();
    }

    /**
     * Finds the model loaders that match this criteria.
     *
     * <p>If a model zoo was set explicitly, only that zoo is searched and a groupId or engine
     * conflict with it is an error. Otherwise every zoo returned by
     * {@link ModelZoo#listModelZoo()} is considered and zoos that do not match the groupId or
     * engine are skipped. Loaders of the remaining zoos are then filtered by artifactId and
     * application.
     *
     * @return the matching loaders, never empty
     * @throws IllegalArgumentException if the input or output class is not set
     * @throws ModelNotFoundException if the explicit zoo conflicts with the groupId or engine,
     *     or if no loader matches
     */
    private List<ModelLoader> resolveModelLoaders() throws ModelNotFoundException {
        if (inputClass == null || outputClass == null) {
            throw new IllegalArgumentException("inputClass and outputClass are required.");
        }
        Logger logger = LoggerFactory.getLogger(ModelZoo.class);
        logger.debug("Loading model with {}", this);

        // Step 1: decide which model zoos to search.
        List<ModelZoo> zoos = new ArrayList<>();
        if (modelZoo != null) {
            logger.debug("Searching model in specified model zoo: {}", modelZoo.getGroupId());
            if (groupId != null && !modelZoo.getGroupId().equals(groupId)) {
                throw new ModelNotFoundException("groupId conflict with ModelZoo criteria." + modelZoo.getGroupId() + " v.s. " + groupId);
            }
            Set<String> supportedEngines = modelZoo.getSupportedEngines();
            if (engine != null && !supportedEngines.contains(engine)) {
                throw new ModelNotFoundException("ModelZoo doesn't support specified engine: " + engine);
            }
            zoos.add(modelZoo);
        } else {
            for (ModelZoo zoo : ModelZoo.listModelZoo()) {
                if (groupId != null && !zoo.getGroupId().equals(groupId)) {
                    logger.debug("Ignore ModelZoo {} by groupId: {}", zoo.getGroupId(), groupId);
                    continue;
                }
                Set<String> supportedEngines = zoo.getSupportedEngines();
                if (engine != null && !supportedEngines.contains(engine)) {
                    logger.debug("Ignore ModelZoo {} by engine: {}", zoo.getGroupId(), engine);
                    continue;
                }
                zoos.add(zoo);
            }
        }

        // A null application is treated like UNDEFINED (no application filter) rather than
        // being handed to Application.matches().
        boolean filterByApplication = application != null && application != Application.UNDEFINED;

        // Step 2: collect the loaders of those zoos that match artifactId and application.
        List<ModelLoader> loaders = new ArrayList<>();
        for (ModelZoo zoo : zoos) {
            String loaderGroupId = zoo.getGroupId();
            for (ModelLoader loader : zoo.getModelLoaders()) {
                Application app = loader.getApplication();
                String loaderArtifactId = loader.getArtifactId();
                logger.debug("Checking ModelLoader: {}", loader);
                if (artifactId != null && !artifactId.equals(loaderArtifactId)) {
                    logger.debug("artifactId mismatch for ModelLoader: {}:{}", loaderGroupId, loaderArtifactId);
                    continue;
                }
                if (filterByApplication && app != Application.UNDEFINED && !app.matches(application)) {
                    logger.debug("application mismatch for ModelLoader: {}:{}", loaderGroupId, loaderArtifactId);
                    continue;
                }
                loaders.add(loader);
            }
        }
        if (loaders.isEmpty()) {
            throw new ModelNotFoundException("No model matching the criteria is found.");
        }
        return loaders;
    }

    /**
     * A builder that collects the settings used to construct a {@link Criteria}.
     *
     * <p>All {@code opt*} methods return this builder so calls can be chained.
     * {@link #setTypes(Class, Class)} returns a new builder bound to the given types that
     * carries over everything configured so far.
     *
     * @param <I> the model input type
     * @param <O> the model output type
     */
    public static final class Builder<I, O> {
        Application application;
        Class<I> inputClass;
        Class<O> outputClass;
        String engine;
        Device device;
        String groupId;
        String artifactId;
        ModelZoo modelZoo;
        Map<String, String> filters;
        Map<String, Object> arguments;
        Map<String, String> options;
        TranslatorFactory factory;
        Block block;
        String modelName;
        Progress progress;
        Translator<I, O> translator;

        /** Creates an empty builder whose application is {@link Application#UNDEFINED}. */
        Builder() {
            this.application = Application.UNDEFINED;
        }

        /**
         * Creates a re-typed builder that takes over the settings of {@code parent}.
         *
         * <p>The filter, argument and option maps are copied so that the two builders do not
         * share mutable state.
         *
         * @param inputClass the new input type
         * @param outputClass the new output type
         * @param parent the builder to take the remaining settings from
         */
        @SuppressWarnings("unchecked")
        private Builder(Class<I> inputClass, Class<O> outputClass, Builder<?, ?> parent) {
            this.inputClass = inputClass;
            this.outputClass = outputClass;
            this.application = parent.application;
            this.engine = parent.engine;
            this.device = parent.device;
            this.groupId = parent.groupId;
            this.artifactId = parent.artifactId;
            this.modelZoo = parent.modelZoo;
            this.filters = copyOrNull(parent.filters);
            this.arguments = copyOrNull(parent.arguments);
            this.options = copyOrNull(parent.options);
            this.factory = parent.factory;
            this.block = parent.block;
            this.modelName = parent.modelName;
            this.progress = parent.progress;
            // Unchecked: the parent's translator is carried over as-is under the new type
            // parameters; nothing here verifies that it actually handles the new types.
            this.translator = (Translator<I, O>) parent.translator;
        }

        /** Returns an insertion-ordered copy of {@code source}, or {@code null} for null. */
        private static <K, V> Map<K, V> copyOrNull(Map<K, V> source) {
            return source == null ? null : new LinkedHashMap<>(source);
        }

        /**
         * Binds the builder to the given input and output types.
         *
         * @param inputClass the model input class
         * @param outputClass the model output class
         * @param <P> the new input type
         * @param <Q> the new output type
         * @return a new builder with the given types and the settings of this builder
         */
        public <P, Q> Builder<P, Q> setTypes(Class<P> inputClass, Class<Q> outputClass) {
            return new Builder<>(inputClass, outputClass, this);
        }

        /**
         * Sets the application.
         *
         * @param application the application to search for
         * @return this builder
         */
        public Builder<I, O> optApplication(Application application) {
            this.application = application;
            return this;
        }

        /**
         * Sets the engine name.
         *
         * @param engine the engine name
         * @return this builder
         */
        public Builder<I, O> optEngine(String engine) {
            this.engine = engine;
            return this;
        }

        /**
         * Sets the device.
         *
         * @param device the device
         * @return this builder
         */
        public Builder<I, O> optDevice(Device device) {
            this.device = device;
            return this;
        }

        /**
         * Sets the group id.
         *
         * @param groupId the group id
         * @return this builder
         */
        public Builder<I, O> optGroupId(String groupId) {
            this.groupId = groupId;
            return this;
        }

        /**
         * Sets the artifact id, optionally together with the group id.
         *
         * <p>If the value contains {@code ':'} it is split on that character: the first token
         * becomes the group id and the second the artifact id, an empty token is stored as
         * {@code null}, and any further tokens are ignored. In that case a previously set
         * group id is overwritten. Otherwise only the artifact id is set.
         *
         * @param artifactId the artifact id, or {@code groupId:artifactId}
         * @return this builder
         */
        public Builder<I, O> optArtifactId(String artifactId) {
            if (artifactId != null && artifactId.contains(":")) {
                // Limit -1 keeps trailing empty tokens, so tokens[1] always exists here.
                String[] tokens = artifactId.split(":", -1);
                this.groupId = tokens[0].isEmpty() ? null : tokens[0];
                this.artifactId = tokens[1].isEmpty() ? null : tokens[1];
            } else {
                this.artifactId = artifactId;
            }
            return this;
        }

        /**
         * Uses a {@link DefaultModelZoo} created from the given URLs as the model zoo.
         *
         * @param modelUrls the model URL string; {@code null} leaves the builder unchanged
         * @return this builder
         */
        public Builder<I, O> optModelUrls(String modelUrls) {
            if (modelUrls != null) {
                this.modelZoo = new DefaultModelZoo(modelUrls);
            }
            return this;
        }

        /**
         * Uses a {@link DefaultModelZoo} created from the URL form of the given path as the
         * model zoo.
         *
         * @param modelPath the model path; {@code null} leaves the builder unchanged
         * @return this builder
         * @throws AssertionError if the path cannot be converted to a URL
         */
        public Builder<I, O> optModelPath(Path modelPath) {
            if (modelPath != null) {
                try {
                    this.modelZoo = new DefaultModelZoo(modelPath.toUri().toURL().toString());
                } catch (MalformedURLException e) {
                    throw new AssertionError("Invalid model path: " + modelPath, e);
                }
            }
            return this;
        }

        /**
         * Sets the model zoo to search.
         *
         * @param modelZoo the model zoo
         * @return this builder
         */
        public Builder<I, O> optModelZoo(ModelZoo modelZoo) {
            this.modelZoo = modelZoo;
            return this;
        }

        /**
         * Replaces all search filters.
         *
         * <p>The given map is copied, so later {@link #optFilter(String, String)} calls never
         * modify the caller's map and work even if that map is unmodifiable.
         *
         * @param filters the filters, or {@code null} to clear them
         * @return this builder
         */
        public Builder<I, O> optFilters(Map<String, String> filters) {
            this.filters = copyOrNull(filters);
            return this;
        }

        /**
         * Adds or replaces a single search filter.
         *
         * @param key the filter key
         * @param value the filter value
         * @return this builder
         */
        public Builder<I, O> optFilter(String key, String value) {
            if (this.filters == null) {
                this.filters = new HashMap<>();
            }
            this.filters.put(key, value);
            return this;
        }

        /**
         * Sets the block.
         *
         * @param block the block
         * @return this builder
         */
        public Builder<I, O> optBlock(Block block) {
            this.block = block;
            return this;
        }

        /**
         * Sets the model name.
         *
         * @param modelName the model name
         * @return this builder
         */
        public Builder<I, O> optModelName(String modelName) {
            this.modelName = modelName;
            return this;
        }

        /**
         * Replaces all arguments.
         *
         * <p>The given map is copied, so later {@link #optArgument(String, Object)} calls
         * never modify the caller's map and work even if that map is unmodifiable.
         *
         * @param arguments the arguments, or {@code null} to clear them
         * @return this builder
         */
        public Builder<I, O> optArguments(Map<String, Object> arguments) {
            this.arguments = copyOrNull(arguments);
            return this;
        }

        /**
         * Adds or replaces a single argument.
         *
         * @param key the argument key
         * @param value the argument value
         * @return this builder
         */
        public Builder<I, O> optArgument(String key, Object value) {
            if (this.arguments == null) {
                this.arguments = new HashMap<>();
            }
            this.arguments.put(key, value);
            return this;
        }

        /**
         * Replaces all options.
         *
         * <p>The given map is copied, so later {@link #optOption(String, String)} calls never
         * modify the caller's map and work even if that map is unmodifiable.
         *
         * @param options the options, or {@code null} to clear them
         * @return this builder
         */
        public Builder<I, O> optOptions(Map<String, String> options) {
            this.options = copyOrNull(options);
            return this;
        }

        /**
         * Adds or replaces a single option.
         *
         * @param key the option key
         * @param value the option value
         * @return this builder
         */
        public Builder<I, O> optOption(String key, String value) {
            if (this.options == null) {
                this.options = new HashMap<>();
            }
            this.options.put(key, value);
            return this;
        }

        /**
         * Sets the translator and clears any translator factory set earlier.
         *
         * @param translator the translator
         * @return this builder
         */
        public Builder<I, O> optTranslator(Translator<I, O> translator) {
            this.factory = null;
            this.translator = translator;
            return this;
        }

        /**
         * Sets the translator factory and clears any translator set earlier.
         *
         * @param factory the translator factory
         * @return this builder
         */
        public Builder<I, O> optTranslatorFactory(TranslatorFactory factory) {
            this.translator = null;
            this.factory = factory;
            return this;
        }

        /**
         * Sets the progress tracker.
         *
         * @param progress the progress tracker
         * @return this builder
         */
        public Builder<I, O> optProgress(Progress progress) {
            this.progress = progress;
            return this;
        }

        /**
         * Builds the criteria from the current settings.
         *
         * <p>If a translator was set and no factory is present, the translator is registered
         * for the builder's input and output classes in a new
         * {@link DefaultTranslatorFactory}, which is stored as this builder's factory before
         * the criteria is created.
         *
         * @return the new {@link Criteria}
         */
        public Criteria<I, O> build() {
            if (this.factory == null && this.translator != null) {
                DefaultTranslatorFactory f = new DefaultTranslatorFactory();
                f.registerTranslator(this.inputClass, this.outputClass, this.translator);
                this.factory = f;
            }
            return new Criteria<>(this);
        }
    }
}

