/*
 * Decompiled with CFR 0.152.
 */
package org.tensorflow.internal.types.registry;

import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import org.tensorflow.TensorMapper;
import org.tensorflow.internal.types.registry.TensorTypeInfo;
import org.tensorflow.proto.framework.DataType;
import org.tensorflow.types.TBfloat16;
import org.tensorflow.types.TBool;
import org.tensorflow.types.TFloat16;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TFloat64;
import org.tensorflow.types.TInt32;
import org.tensorflow.types.TInt64;
import org.tensorflow.types.TString;
import org.tensorflow.types.TUint16;
import org.tensorflow.types.TUint8;
import org.tensorflow.types.annotation.TensorType;
import org.tensorflow.types.family.TType;

/**
 * Registry of the tensor types known to this library, indexed both by the numeric code of their
 * {@link DataType} and by their Java class.
 *
 * <p>All types are registered from the static initializer of this class. The lookup maps are not
 * modified afterwards, because {@link #register(Class)} is private and only called there.
 */
public final class TensorTypeRegistry {
    /** Registered type info, keyed by {@link DataType#getNumber()} (and its reference-type code). */
    private static final Map<Integer, TensorTypeInfo<?>> TYPES_BY_CODE = new HashMap<>();
    /** Registered type info, keyed by the tensor type class it describes. */
    private static final Map<Class<? extends TType>, TensorTypeInfo<?>> TYPES_BY_CLASS = new HashMap<>();

    /**
     * Finds the information of the tensor type registered for the given data type.
     *
     * <p>The returned value is produced by an unchecked cast. The caller is responsible for
     * choosing a {@code T} that matches the tensor type actually registered for {@code dataType}.
     *
     * @param dataType data type to look up, must not be null
     * @param <T> tensor type expected by the caller
     * @return the registered tensor type information
     * @throws NullPointerException if {@code dataType} is null
     * @throws IllegalArgumentException if no tensor type has been registered for {@code dataType}
     */
    @SuppressWarnings("unchecked") // T is chosen by the caller; see Javadoc.
    public static <T extends TType> TensorTypeInfo<T> find(DataType dataType) {
        Objects.requireNonNull(dataType, "dataType");
        TensorTypeInfo<?> typeInfo = TYPES_BY_CODE.get(dataType.getNumber());
        if (typeInfo == null) {
            throw new IllegalArgumentException("No tensor type has been registered for data type " + dataType);
        }
        return (TensorTypeInfo<T>) typeInfo;
    }

    /**
     * Finds the information registered for the given tensor type class.
     *
     * @param type tensor type class to look up, must not be null
     * @param <T> tensor type
     * @return the registered tensor type information
     * @throws NullPointerException if {@code type} is null
     * @throws IllegalArgumentException if {@code type} has not been registered as a tensor type
     */
    @SuppressWarnings("unchecked") // Safe: register() stores a TensorTypeInfo<T> under its Class<T>.
    public static <T extends TType> TensorTypeInfo<T> find(Class<T> type) {
        Objects.requireNonNull(type, "type");
        TensorTypeInfo<?> typeInfo = TYPES_BY_CLASS.get(type);
        if (typeInfo == null) {
            throw new IllegalArgumentException("Class \"" + type.getName() + "\" is not registered as a tensor type");
        }
        return (TensorTypeInfo<T>) typeInfo;
    }

    /**
     * Registers a tensor type from the metadata of its {@link TensorType} annotation.
     *
     * <p>The annotation's mapper class is instantiated through its parameter-less constructor.
     * The resulting info is indexed by class, by data type code, and by data type code + 100.
     *
     * @param type tensor type class to register, must be annotated with {@link TensorType}
     * @param <T> tensor type
     * @throws IllegalArgumentException if {@code type} is not annotated with {@link TensorType},
     *     or if its mapper class cannot be instantiated
     */
    private static <T extends TType> void register(Class<T> type) {
        TensorType typeAnnot = type.getDeclaredAnnotation(TensorType.class);
        if (typeAnnot == null) {
            throw new IllegalArgumentException("Class \"" + type.getName() + "\" must be annotated with @TensorType to be registered as a tensor type");
        }
        TensorMapper<?> mapper;
        try {
            // getDeclaredConstructor().newInstance() replaces the deprecated Class.newInstance().
            // It wraps constructor failures in InvocationTargetException instead of
            // propagating undeclared checked exceptions.
            mapper = typeAnnot.mapperClass().getDeclaredConstructor().newInstance();
        } catch (ReflectiveOperationException e) {
            // Name the mapper class (the one that failed to instantiate) and keep the cause.
            throw new IllegalArgumentException("Class \"" + typeAnnot.mapperClass().getName() + "\" must have a public parameter-less constructor to be used as a tensor mapper", e);
        }
        TensorTypeInfo<T> typeInfo = new TensorTypeInfo<T>(type, typeAnnot.dataType(), typeAnnot.byteSize(), mapper);
        TYPES_BY_CLASS.put(type, typeInfo);

        // If several tensor types share the same data type code, the last registered one wins
        // (Map.put overwrites the previous entry).
        int code = typeInfo.dataType().getNumber();
        TYPES_BY_CODE.put(code, typeInfo);
        // Also map the matching reference-type code. In TensorFlow's DataType enum, each
        // *_REF value is its base value + 100 (e.g. DT_FLOAT_REF = DT_FLOAT + 100).
        TYPES_BY_CODE.put(code + 100, typeInfo);
    }

    static {
        TensorTypeRegistry.register(TBool.class);
        TensorTypeRegistry.register(TFloat64.class);
        TensorTypeRegistry.register(TFloat32.class);
        TensorTypeRegistry.register(TFloat16.class);
        TensorTypeRegistry.register(TInt32.class);
        TensorTypeRegistry.register(TInt64.class);
        TensorTypeRegistry.register(TString.class);
        TensorTypeRegistry.register(TUint8.class);
        TensorTypeRegistry.register(TUint16.class);
        TensorTypeRegistry.register(TBfloat16.class);
    }
}

