/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.imports.graphmapper.tf.tensors;

import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
import java.nio.ShortBuffer;
import java.util.Arrays;
import org.bytedeco.javacpp.indexer.Bfloat16ArrayIndexer;
import org.bytedeco.javacpp.indexer.HalfIndexer;
import org.nd4j.common.util.ArrayUtil;
import org.nd4j.imports.graphmapper.tf.tensors.TFTensorMapper;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.options.ArrayOptionsHelper;
import org.nd4j.linalg.factory.Nd4j;
import org.tensorflow.framework.TensorProto;

public class TFTensorMappers {
    private TFTensorMappers() {
        // Holder for the static factory and the mapper implementations; never instantiated.
    }

    /**
     * Selects the {@link TFTensorMapper} implementation for a TensorFlow tensor based on its dtype.
     *
     * @param tp TensorFlow tensor proto to wrap
     * @return a mapper bound to {@code tp}
     * @throws IllegalStateException if the dtype is quantized, complex, a reference type, or has
     *         no mapper at all
     */
    public static TFTensorMapper<?, ?> newMapper(TensorProto tp) {
        switch (tp.getDtype()) {
            // Floating point types
            case DT_HALF:
                return new Float16TensorMapper(tp);
            case DT_FLOAT:
                return new Float32TensorMapper(tp);
            case DT_DOUBLE:
                return new Float64TensorMapper(tp);
            case DT_BFLOAT16:
                return new BFloat16TensorMapper(tp);
            // Signed integer types
            case DT_INT8:
                return new Int8TensorMapper(tp);
            case DT_INT16:
                return new Int16TensorMapper(tp);
            case DT_INT32:
                return new Int32TensorMapper(tp);
            case DT_INT64:
                return new Int64TensorMapper(tp);
            // Non-numeric types
            case DT_STRING:
                return new StringTensorMapper(tp);
            case DT_BOOL:
                return new BoolTensorMapper(tp);
            // Unsigned integer types
            case DT_UINT8:
                return new UInt8TensorMapper(tp);
            case DT_UINT16:
                return new UInt16TensorMapper(tp);
            case DT_UINT32:
                return new UInt32TensorMapper(tp);
            case DT_UINT64:
                return new UInt64TensorMapper(tp);
            default:
                throw new IllegalStateException(unmappableTypePrefix(tp) + tp.getDtype());
        }
    }

    /**
     * Chooses the error-message prefix that explains why {@code tp}'s dtype has no mapper.
     *
     * @param tp tensor whose dtype could not be mapped
     * @return the message prefix; the dtype name is appended by the caller
     */
    private static String unmappableTypePrefix(TensorProto tp) {
        switch (tp.getDtype()) {
            case DT_QINT8:
            case DT_QUINT8:
            case DT_QINT32:
            case DT_QINT16:
            case DT_QUINT16:
                return "Unable to map quantized type: ";
            case DT_COMPLEX64:
            case DT_COMPLEX128:
                return "Unable to map complex type: ";
            case DT_FLOAT_REF:
            case DT_DOUBLE_REF:
            case DT_INT32_REF:
            case DT_UINT8_REF:
            case DT_INT16_REF:
            case DT_INT8_REF:
            case DT_STRING_REF:
            case DT_COMPLEX64_REF:
            case DT_INT64_REF:
            case DT_BOOL_REF:
            case DT_QINT8_REF:
            case DT_QUINT8_REF:
            case DT_QINT32_REF:
            case DT_BFLOAT16_REF:
            case DT_QINT16_REF:
            case DT_QUINT16_REF:
            case DT_UINT16_REF:
            case DT_COMPLEX128_REF:
            case DT_HALF_REF:
            case DT_RESOURCE_REF:
            case DT_VARIANT_REF:
            case DT_UINT32_REF:
            case DT_UINT64_REF:
                return "Unable to map reference type: ";
            default:
                return "Unable to map type: ";
        }
    }

    /**
     * Maps TensorFlow {@code DT_BOOL} tensors to ND4J boolean arrays, from either {@code bool_val}
     * or raw {@code tensor_content}.
     */
    public static class BoolTensorMapper
    extends BaseTensorMapper<boolean[], ByteBuffer> {
        public BoolTensorMapper(TensorProto tensorProto) {
            super(tensorProto);
        }

        @Override
        public int valueCount() {
            return this.tfTensor.getBoolValCount();
        }

        @Override
        public boolean[] newArray(int length) {
            return new boolean[length];
        }

        /**
         * Returns the raw bytes unchanged.
         * Assumes TF serializes bool content as one byte per element (C++ {@code sizeof(bool) == 1}).
         */
        @Override
        public ByteBuffer getBuffer(ByteBuffer bb) {
            return bb;
        }

        @Override
        public void getValue(boolean[] jArr, int i) {
            jArr[i] = this.tfTensor.getBoolVal(i);
        }

        /** Decodes element {@code i}: any non-zero byte is {@code true}. */
        @Override
        public void getValue(boolean[] jArr, ByteBuffer buffer, int i) {
            jArr[i] = buffer.get(i) != 0;
        }

        /**
         * Builds the boolean array. A single stored value is repeated to fill a larger shape,
         * matching how the floating point mappers handle this case.
         */
        @Override
        public INDArray arrayFor(long[] shape, boolean[] jArr) {
            long length = ArrayUtil.prod(shape);
            boolean[] values = jArr;
            if (jArr.length == 1 && length > 1) {
                values = new boolean[Math.toIntExact(length)];
                Arrays.fill(values, jArr[0]);
            }
            return Nd4j.create(values).reshape(shape);
        }
    }

    /**
     * Maps TensorFlow {@code DT_STRING} tensors to ND4J string arrays, decoding each
     * {@code string_val} entry as UTF-8. Raw {@code tensor_content} is not supported.
     */
    public static class StringTensorMapper
    extends BaseTensorMapper<String[], ByteBuffer> {
        public StringTensorMapper(TensorProto tensorProto) {
            super(tensorProto);
        }

        @Override
        public int valueCount() {
            return this.tfTensor.getStringValCount();
        }

        @Override
        public String[] newArray(int length) {
            return new String[length];
        }

        /** @throws UnsupportedOperationException always; strings are only read from {@code string_val} */
        @Override
        public ByteBuffer getBuffer(ByteBuffer bb) {
            throw new UnsupportedOperationException("Not supported for String types");
        }

        @Override
        public void getValue(String[] jArr, int i) {
            jArr[i] = this.tfTensor.getStringVal(i).toStringUtf8();
        }

        /** @throws UnsupportedOperationException always; strings are only read from {@code string_val} */
        @Override
        public void getValue(String[] jArr, ByteBuffer buffer, int i) {
            throw new UnsupportedOperationException("Not supported for String types");
        }

        /**
         * Builds the string array. A single stored value is repeated to fill a larger shape,
         * matching how the floating point mappers handle this case.
         */
        @Override
        public INDArray arrayFor(long[] shape, String[] jArr) {
            long length = ArrayUtil.prod(shape);
            String[] values = jArr;
            if (jArr.length == 1 && length > 1) {
                values = new String[Math.toIntExact(length)];
                Arrays.fill(values, jArr[0]);
            }
            return Nd4j.create(values).reshape(shape);
        }
    }

    /**
     * Maps TensorFlow {@code DT_UINT64} tensors to ND4J arrays of this mapper's data type.
     * Values are held in Java {@code long}s. Raw bit patterns are kept, so values above
     * {@link Long#MAX_VALUE} read back as negative {@code long}s.
     */
    public static class UInt64TensorMapper
    extends BaseTensorMapper<long[], LongBuffer> {
        public UInt64TensorMapper(TensorProto tensorProto) {
            super(tensorProto);
        }

        // NOTE(review): values are read from int64_val. Newer TF TensorProto versions store
        // DT_UINT64 in a separate uint64_val field; confirm against the generated proto in use.
        @Override
        public int valueCount() {
            return this.tfTensor.getInt64ValCount();
        }

        @Override
        public long[] newArray(int length) {
            return new long[length];
        }

        @Override
        public LongBuffer getBuffer(ByteBuffer bb) {
            return bb.asLongBuffer();
        }

        @Override
        public void getValue(long[] jArr, int i) {
            jArr[i] = this.tfTensor.getInt64Val(i);
        }

        @Override
        public void getValue(long[] jArr, LongBuffer buffer, int i) {
            jArr[i] = buffer.get(i);
        }

        /**
         * Wraps the values in a typed, C-ordered array. A single stored value is repeated to
         * fill a larger shape, so the buffer always covers every element.
         */
        @Override
        public INDArray arrayFor(long[] shape, long[] jArr) {
            DataType dt = this.dataType();
            long length = ArrayUtil.prod(shape);
            long[] values = jArr;
            if (jArr.length == 1 && length > 1) {
                values = new long[Math.toIntExact(length)];
                Arrays.fill(values, jArr[0]);
            }
            return Nd4j.create(Nd4j.createTypedBuffer(values, dt), shape, Nd4j.getStrides(shape, 'c'), 0L, 'c', dt);
        }
    }

    /**
     * Maps TensorFlow {@code DT_UINT32} tensors to ND4J arrays of this mapper's data type.
     * Values are widened to {@code long}; binary content is zero-extended from 32 bits.
     */
    public static class UInt32TensorMapper
    extends BaseTensorMapper<long[], IntBuffer> {
        public UInt32TensorMapper(TensorProto tensorProto) {
            super(tensorProto);
        }

        // NOTE(review): values are read from int64_val. Newer TF TensorProto versions store
        // DT_UINT32 in a separate uint32_val field; confirm against the generated proto in use.
        @Override
        public int valueCount() {
            return this.tfTensor.getInt64ValCount();
        }

        @Override
        public long[] newArray(int length) {
            return new long[length];
        }

        @Override
        public IntBuffer getBuffer(ByteBuffer bb) {
            return bb.asIntBuffer();
        }

        @Override
        public void getValue(long[] jArr, int i) {
            jArr[i] = this.tfTensor.getInt64Val(i);
        }

        /** Reads element {@code i} as an unsigned 32-bit value. */
        @Override
        public void getValue(long[] jArr, IntBuffer buffer, int i) {
            jArr[i] = Integer.toUnsignedLong(buffer.get(i));
        }

        /**
         * Wraps the values in a typed, C-ordered array. A single stored value is repeated to
         * fill a larger shape, so the buffer always covers every element.
         */
        @Override
        public INDArray arrayFor(long[] shape, long[] jArr) {
            DataType dt = this.dataType();
            long length = ArrayUtil.prod(shape);
            long[] values = jArr;
            if (jArr.length == 1 && length > 1) {
                values = new long[Math.toIntExact(length)];
                Arrays.fill(values, jArr[0]);
            }
            return Nd4j.create(Nd4j.createTypedBuffer(values, dt), shape, Nd4j.getStrides(shape, 'c'), 0L, 'c', dt);
        }
    }

    /**
     * Maps TensorFlow {@code DT_UINT16} tensors to ND4J arrays of this mapper's data type.
     * Values come from {@code int_val}, or from binary content zero-extended from 16 bits.
     */
    public static class UInt16TensorMapper
    extends BaseTensorMapper<int[], ShortBuffer> {
        public UInt16TensorMapper(TensorProto tensorProto) {
            super(tensorProto);
        }

        @Override
        public int valueCount() {
            return this.tfTensor.getIntValCount();
        }

        @Override
        public int[] newArray(int length) {
            return new int[length];
        }

        @Override
        public ShortBuffer getBuffer(ByteBuffer bb) {
            return bb.asShortBuffer();
        }

        @Override
        public void getValue(int[] jArr, int i) {
            jArr[i] = this.tfTensor.getIntVal(i);
        }

        /** Reads element {@code i} as an unsigned 16-bit value. */
        @Override
        public void getValue(int[] jArr, ShortBuffer buffer, int i) {
            jArr[i] = Short.toUnsignedInt(buffer.get(i));
        }

        /**
         * Wraps the values in a typed, C-ordered array. A single stored value is repeated to
         * fill a larger shape, so the buffer always covers every element.
         */
        @Override
        public INDArray arrayFor(long[] shape, int[] jArr) {
            DataType dt = this.dataType();
            long length = ArrayUtil.prod(shape);
            int[] values = jArr;
            if (jArr.length == 1 && length > 1) {
                values = new int[Math.toIntExact(length)];
                Arrays.fill(values, jArr[0]);
            }
            return Nd4j.create(Nd4j.createTypedBuffer(values, dt), shape, Nd4j.getStrides(shape, 'c'), 0L, 'c', dt);
        }
    }

    /**
     * Maps TensorFlow {@code DT_UINT8} tensors to ND4J arrays of this mapper's data type.
     * Values come from {@code int_val}, or from binary content zero-extended from 8 bits.
     */
    public static class UInt8TensorMapper
    extends BaseTensorMapper<int[], ByteBuffer> {
        public UInt8TensorMapper(TensorProto tensorProto) {
            super(tensorProto);
        }

        @Override
        public int valueCount() {
            return this.tfTensor.getIntValCount();
        }

        @Override
        public int[] newArray(int length) {
            return new int[length];
        }

        /** One byte per element, so the raw buffer is used as-is. */
        @Override
        public ByteBuffer getBuffer(ByteBuffer bb) {
            return bb;
        }

        @Override
        public void getValue(int[] jArr, int i) {
            jArr[i] = this.tfTensor.getIntVal(i);
        }

        /** Reads element {@code i} as an unsigned 8-bit value. */
        @Override
        public void getValue(int[] jArr, ByteBuffer buffer, int i) {
            jArr[i] = Byte.toUnsignedInt(buffer.get(i));
        }

        /**
         * Wraps the values in a typed, C-ordered array. A single stored value is repeated to
         * fill a larger shape, so the buffer always covers every element.
         */
        @Override
        public INDArray arrayFor(long[] shape, int[] jArr) {
            DataType dt = this.dataType();
            long length = ArrayUtil.prod(shape);
            int[] values = jArr;
            if (jArr.length == 1 && length > 1) {
                values = new int[Math.toIntExact(length)];
                Arrays.fill(values, jArr[0]);
            }
            return Nd4j.create(Nd4j.createTypedBuffer(values, dt), shape, Nd4j.getStrides(shape, 'c'), 0L, 'c', dt);
        }
    }

    /**
     * Maps TensorFlow {@code DT_INT64} tensors to ND4J arrays of this mapper's data type,
     * from either {@code int64_val} or raw {@code tensor_content}.
     */
    public static class Int64TensorMapper
    extends BaseTensorMapper<long[], LongBuffer> {
        public Int64TensorMapper(TensorProto tensorProto) {
            super(tensorProto);
        }

        @Override
        public int valueCount() {
            return this.tfTensor.getInt64ValCount();
        }

        @Override
        public long[] newArray(int length) {
            return new long[length];
        }

        @Override
        public LongBuffer getBuffer(ByteBuffer bb) {
            return bb.asLongBuffer();
        }

        @Override
        public void getValue(long[] jArr, int i) {
            jArr[i] = this.tfTensor.getInt64Val(i);
        }

        @Override
        public void getValue(long[] jArr, LongBuffer buffer, int i) {
            jArr[i] = buffer.get(i);
        }

        /**
         * Wraps the values in a typed, C-ordered array. A single stored value is repeated to
         * fill a larger shape, so the buffer always covers every element.
         */
        @Override
        public INDArray arrayFor(long[] shape, long[] jArr) {
            DataType dt = this.dataType();
            long length = ArrayUtil.prod(shape);
            long[] values = jArr;
            if (jArr.length == 1 && length > 1) {
                values = new long[Math.toIntExact(length)];
                Arrays.fill(values, jArr[0]);
            }
            return Nd4j.create(Nd4j.createTypedBuffer(values, dt), shape, Nd4j.getStrides(shape, 'c'), 0L, 'c', dt);
        }
    }

    /**
     * Maps TensorFlow {@code DT_INT32} tensors to ND4J arrays of this mapper's data type,
     * from either {@code int_val} or raw {@code tensor_content}.
     */
    public static class Int32TensorMapper
    extends BaseTensorMapper<int[], IntBuffer> {
        public Int32TensorMapper(TensorProto tensorProto) {
            super(tensorProto);
        }

        @Override
        public int valueCount() {
            return this.tfTensor.getIntValCount();
        }

        @Override
        public int[] newArray(int length) {
            return new int[length];
        }

        @Override
        public IntBuffer getBuffer(ByteBuffer bb) {
            return bb.asIntBuffer();
        }

        @Override
        public void getValue(int[] jArr, int i) {
            jArr[i] = this.tfTensor.getIntVal(i);
        }

        @Override
        public void getValue(int[] jArr, IntBuffer buffer, int i) {
            jArr[i] = buffer.get(i);
        }

        /**
         * Wraps the values in a typed, C-ordered array. A single stored value is repeated to
         * fill a larger shape, so the buffer always covers every element.
         */
        @Override
        public INDArray arrayFor(long[] shape, int[] jArr) {
            DataType dt = this.dataType();
            long length = ArrayUtil.prod(shape);
            int[] values = jArr;
            if (jArr.length == 1 && length > 1) {
                values = new int[Math.toIntExact(length)];
                Arrays.fill(values, jArr[0]);
            }
            return Nd4j.create(Nd4j.createTypedBuffer(values, dt), shape, Nd4j.getStrides(shape, 'c'), 0L, 'c', dt);
        }
    }

    /**
     * Maps TensorFlow {@code DT_INT16} tensors to ND4J arrays of this mapper's data type.
     * Values come from {@code int_val}, or from binary content sign-extended from 16 bits.
     */
    public static class Int16TensorMapper
    extends BaseTensorMapper<int[], ShortBuffer> {
        public Int16TensorMapper(TensorProto tensorProto) {
            super(tensorProto);
        }

        @Override
        public int valueCount() {
            return this.tfTensor.getIntValCount();
        }

        @Override
        public int[] newArray(int length) {
            return new int[length];
        }

        @Override
        public ShortBuffer getBuffer(ByteBuffer bb) {
            return bb.asShortBuffer();
        }

        @Override
        public void getValue(int[] jArr, int i) {
            jArr[i] = this.tfTensor.getIntVal(i);
        }

        @Override
        public void getValue(int[] jArr, ShortBuffer buffer, int i) {
            jArr[i] = buffer.get(i);
        }

        /**
         * Wraps the values in a typed, C-ordered array. A single stored value is repeated to
         * fill a larger shape, so the buffer always covers every element.
         */
        @Override
        public INDArray arrayFor(long[] shape, int[] jArr) {
            DataType dt = this.dataType();
            long length = ArrayUtil.prod(shape);
            int[] values = jArr;
            if (jArr.length == 1 && length > 1) {
                values = new int[Math.toIntExact(length)];
                Arrays.fill(values, jArr[0]);
            }
            return Nd4j.create(Nd4j.createTypedBuffer(values, dt), shape, Nd4j.getStrides(shape, 'c'), 0L, 'c', dt);
        }
    }

    /**
     * Maps TensorFlow {@code DT_INT8} tensors to ND4J arrays of this mapper's data type.
     * Values come from {@code int_val}, or from binary content sign-extended from 8 bits.
     */
    public static class Int8TensorMapper
    extends BaseTensorMapper<int[], ByteBuffer> {
        public Int8TensorMapper(TensorProto tensorProto) {
            super(tensorProto);
        }

        @Override
        public int valueCount() {
            return this.tfTensor.getIntValCount();
        }

        @Override
        public int[] newArray(int length) {
            return new int[length];
        }

        /** One byte per element, so the raw buffer is used as-is. */
        @Override
        public ByteBuffer getBuffer(ByteBuffer bb) {
            return bb;
        }

        @Override
        public void getValue(int[] jArr, int i) {
            jArr[i] = this.tfTensor.getIntVal(i);
        }

        @Override
        public void getValue(int[] jArr, ByteBuffer buffer, int i) {
            jArr[i] = buffer.get(i);
        }

        /**
         * Wraps the values in a typed, C-ordered array. A single stored value is repeated to
         * fill a larger shape, so the buffer always covers every element.
         */
        @Override
        public INDArray arrayFor(long[] shape, int[] jArr) {
            DataType dt = this.dataType();
            long length = ArrayUtil.prod(shape);
            int[] values = jArr;
            if (jArr.length == 1 && length > 1) {
                values = new int[Math.toIntExact(length)];
                Arrays.fill(values, jArr[0]);
            }
            return Nd4j.create(Nd4j.createTypedBuffer(values, dt), shape, Nd4j.getStrides(shape, 'c'), 0L, 'c', dt);
        }
    }

    /**
     * Maps TensorFlow {@code DT_BFLOAT16} tensors to ND4J {@code BFLOAT16} arrays. Values are
     * decoded to {@code float} first and cast back to bfloat16 when the array is built.
     */
    public static class BFloat16TensorMapper
    extends BaseTensorMapper<float[], ShortBuffer> {
        public BFloat16TensorMapper(TensorProto tensorProto) {
            super(tensorProto);
        }

        @Override
        public int valueCount() {
            return this.tfTensor.getHalfValCount();
        }

        @Override
        public float[] newArray(int length) {
            return new float[length];
        }

        @Override
        public ShortBuffer getBuffer(ByteBuffer bb) {
            return bb.asShortBuffer();
        }

        /** The 16 bfloat16 bits of each value are stored in the {@code half_val} int field. */
        @Override
        public void getValue(float[] jArr, int i) {
            int asIntBytes = this.tfTensor.getHalfVal(i);
            jArr[i] = Bfloat16ArrayIndexer.toFloat(asIntBytes);
        }

        /** Decodes element {@code i} from raw 16-bit content; the mask prevents sign extension. */
        @Override
        public void getValue(float[] jArr, ShortBuffer buffer, int i) {
            jArr[i] = Bfloat16ArrayIndexer.toFloat(buffer.get(i) & 0xFFFF);
        }

        /**
         * Builds a {@code BFLOAT16} array. A single stored value is repeated to fill a larger shape.
         * The fill case previously produced a {@code HALF} array by mistake.
         */
        @Override
        public INDArray arrayFor(long[] shape, float[] jArr) {
            if (jArr.length == 1 && ArrayUtil.prod(shape) > 1) {
                return Nd4j.createUninitialized(DataType.BFLOAT16, shape).assign(Float.valueOf(jArr[0]));
            }
            return Nd4j.create(jArr, shape, 'c').castTo(DataType.BFLOAT16);
        }
    }

    /**
     * Converts TensorFlow {@code DT_DOUBLE} tensors into ND4J double arrays, reading either the
     * {@code double_val} list or raw {@code tensor_content}.
     */
    public static class Float64TensorMapper
    extends BaseTensorMapper<double[], DoubleBuffer> {
        public Float64TensorMapper(TensorProto tensorProto) {
            super(tensorProto);
        }

        @Override
        public double[] newArray(int length) {
            return new double[length];
        }

        @Override
        public int valueCount() {
            return tfTensor.getDoubleValCount();
        }

        @Override
        public void getValue(double[] jArr, int i) {
            double value = tfTensor.getDoubleVal(i);
            jArr[i] = value;
        }

        @Override
        public DoubleBuffer getBuffer(ByteBuffer bb) {
            return bb.asDoubleBuffer();
        }

        @Override
        public void getValue(double[] jArr, DoubleBuffer buffer, int i) {
            double value = buffer.get(i);
            jArr[i] = value;
        }

        /**
         * Builds a C-ordered double array. When only one value is stored for a multi-element
         * shape, that value fills the whole array.
         */
        @Override
        public INDArray arrayFor(long[] shape, double[] jArr) {
            boolean fillFromSingleValue = jArr.length == 1 && ArrayUtil.prod(shape) > 1;
            return fillFromSingleValue ? Nd4j.valueArrayOf(shape, jArr[0]) : Nd4j.create(jArr, shape, 'c');
        }
    }

    /**
     * Converts TensorFlow {@code DT_FLOAT} tensors into ND4J float arrays, reading either the
     * {@code float_val} list or raw {@code tensor_content}.
     */
    public static class Float32TensorMapper
    extends BaseTensorMapper<float[], FloatBuffer> {
        public Float32TensorMapper(TensorProto tensorProto) {
            super(tensorProto);
        }

        @Override
        public float[] newArray(int length) {
            return new float[length];
        }

        @Override
        public int valueCount() {
            return tfTensor.getFloatValCount();
        }

        @Override
        public void getValue(float[] jArr, int i) {
            float value = tfTensor.getFloatVal(i);
            jArr[i] = value;
        }

        @Override
        public FloatBuffer getBuffer(ByteBuffer bb) {
            return bb.asFloatBuffer();
        }

        @Override
        public void getValue(float[] jArr, FloatBuffer buffer, int i) {
            float value = buffer.get(i);
            jArr[i] = value;
        }

        /**
         * Builds a C-ordered float array. When only one value is stored for a multi-element
         * shape, that value fills the whole array.
         */
        @Override
        public INDArray arrayFor(long[] shape, float[] jArr) {
            boolean fillFromSingleValue = jArr.length == 1 && ArrayUtil.prod(shape) > 1;
            return fillFromSingleValue ? Nd4j.valueArrayOf(shape, jArr[0]) : Nd4j.create(jArr, shape, 'c');
        }
    }

    /**
     * Maps TensorFlow {@code DT_HALF} tensors to ND4J {@code HALF} arrays. Values are decoded to
     * {@code float} first and cast back to half precision when the array is built.
     */
    public static class Float16TensorMapper
    extends BaseTensorMapper<float[], Buffer> {
        public Float16TensorMapper(TensorProto tensorProto) {
            super(tensorProto);
        }

        @Override
        public int valueCount() {
            return this.tfTensor.getHalfValCount();
        }

        @Override
        public float[] newArray(int length) {
            return new float[length];
        }

        /**
         * Views raw content as 16-bit elements. The declared type stays {@link Buffer} so the
         * class's generic signature is unchanged for existing callers.
         */
        @Override
        public Buffer getBuffer(ByteBuffer bb) {
            return bb.asShortBuffer();
        }

        /** The 16 half-precision bits of each value are stored in the {@code half_val} int field. */
        @Override
        public void getValue(float[] jArr, int i) {
            int asIntBytes = this.tfTensor.getHalfVal(i);
            jArr[i] = HalfIndexer.toFloat(asIntBytes);
        }

        /**
         * Decodes element {@code i} from the {@link ShortBuffer} produced by {@link #getBuffer};
         * the mask prevents sign extension.
         */
        @Override
        public void getValue(float[] jArr, Buffer buffer, int i) {
            jArr[i] = HalfIndexer.toFloat(((ShortBuffer) buffer).get(i) & 0xFFFF);
        }

        /** Builds a {@code HALF} array; a single stored value is repeated to fill a larger shape. */
        @Override
        public INDArray arrayFor(long[] shape, float[] jArr) {
            if (jArr.length == 1 && ArrayUtil.prod(shape) > 1) {
                return Nd4j.createUninitialized(DataType.HALF, shape).assign(Float.valueOf(jArr[0]));
            }
            return Nd4j.create(jArr, shape, 'c').castTo(DataType.HALF);
        }
    }

    /**
     * Shared logic for converting a TensorFlow {@link TensorProto} into an {@link INDArray}.
     * Subclasses supply the per-dtype pieces: the Java array type {@code T}, the typed NIO view
     * {@code U} of raw content, and how to read and wrap individual values.
     *
     * @param <T> Java primitive/object array type holding decoded values
     * @param <U> NIO buffer view used when reading {@code tensor_content}
     */
    public static abstract class BaseTensorMapper<T, U extends Buffer>
    implements TFTensorMapper<T, U> {
        protected TensorProto tfTensor;

        public BaseTensorMapper(TensorProto tensorProto) {
            this.tfTensor = tensorProto;
        }

        /** Returns the ND4J data type corresponding to the tensor's TensorFlow dtype. */
        @Override
        public DataType dataType() {
            return ArrayOptionsHelper.convertToDataType(this.tfTensor.getDtype());
        }

        /** Returns the tensor's dimension sizes in declaration order; a scalar yields an empty array. */
        @Override
        public long[] shape() {
            int dims = this.tfTensor.getTensorShape().getDimCount();
            long[] arrayShape = new long[dims];
            for (int e = 0; e < dims; ++e) {
                arrayShape[e] = this.tfTensor.getTensorShape().getDim(e).getSize();
            }
            return arrayShape;
        }

        @Override
        public boolean isEmpty() {
            return this.valueSource() == TFTensorMapper.ValueSource.EMPTY;
        }

        /**
         * Determines where the tensor's values live. The typed value list is checked first
         * ({@code VALUE_COUNT}), then raw {@code tensor_content} ({@code BINARY}); otherwise
         * the result is {@code EMPTY}.
         */
        @Override
        public TFTensorMapper.ValueSource valueSource() {
            if (this.valueCount() > 0) {
                return TFTensorMapper.ValueSource.VALUE_COUNT;
            }
            if (this.tfTensor.getTensorContent() != null && this.tfTensor.getTensorContent().size() > 0) {
                return TFTensorMapper.ValueSource.BINARY;
            }
            return TFTensorMapper.ValueSource.EMPTY;
        }

        /**
         * Decodes the tensor into an {@link INDArray} of its shape.
         *
         * @throws RuntimeException if the value source is not recognised
         */
        @Override
        public INDArray toNDArray() {
            DataType dt = this.dataType();
            TFTensorMapper.ValueSource vs = this.valueSource();
            long[] shape = this.shape();
            switch (vs) {
                case EMPTY:
                    return Nd4j.create(dt, shape);
                case VALUE_COUNT: {
                    int n = this.valueCount();
                    T array = this.newArray(n);
                    for (int i = 0; i < n; ++i) {
                        this.getValue(array, i);
                    }
                    return this.arrayFor(shape, array);
                }
                case BINARY: {
                    // Raw content is a memcpy of the producer's memory, so it is read in native byte order.
                    U buffer = this.getBuffer(this.tfTensor.getTensorContent().asReadOnlyByteBuffer().order(ByteOrder.nativeOrder()));
                    int m = buffer.capacity();
                    T array = this.newArray(m);
                    for (int i = 0; i < m; ++i) {
                        this.getValue(array, buffer, i);
                    }
                    return this.arrayFor(shape, array);
                }
                default:
                    throw new RuntimeException("Error converting TF tensor to INDArray");
            }
        }
    }
}

