/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.mxnet.engine;

import ai.djl.Device;
import ai.djl.mxnet.engine.MxNDArray;
import ai.djl.mxnet.engine.MxNDManager;
import ai.djl.mxnet.engine.MxOpParams;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import com.sun.jna.Pointer;

/**
 * MXNet {@link NDArray} whose overrides adapt behaviour to the MXNet 1.6.0 operator set.
 *
 * <p>Overrides either route to a specific native operator, short-circuit empty/scalar inputs
 * before calling a native operator, or normalize index-producing results to {@link
 * DataType#INT64}. {@link #matMul(NDArray)} is explicitly unsupported for MXNet 1.6.0.
 */
public class MxNDArray16 extends MxNDArray {

    /**
     * Creates an array over the native {@code handle} with explicitly supplied metadata.
     *
     * @param manager the manager that owns this array
     * @param handle the native array pointer
     * @param device the device the array lives on
     * @param shape the array shape
     * @param dataType the element type
     * @param hasGradient whether a gradient is attached
     */
    MxNDArray16(
            MxNDManager manager,
            Pointer handle,
            Device device,
            Shape shape,
            DataType dataType,
            boolean hasGradient) {
        super(manager, handle, device, shape, dataType, hasGradient);
    }

    /**
     * Creates an array over the native {@code handle}, delegating metadata handling to {@link
     * MxNDArray}.
     *
     * @param manager the manager that owns this array
     * @param handle the native array pointer
     */
    MxNDArray16(MxNDManager manager, Pointer handle) {
        super(manager, handle);
    }

    /** {@inheritDoc} Implemented with the {@code _np_zeros_like} operator. */
    @Override
    public NDArray zerosLike() {
        return manager.invoke("_np_zeros_like", this, null);
    }

    /** {@inheritDoc} Implemented with the {@code _np_ones_like} operator. */
    @Override
    public NDArray onesLike() {
        return manager.invoke("_np_ones_like", this, null);
    }

    /**
     * {@inheritDoc}
     *
     * <p>The native {@code argsort} operator is asked for {@link DataType#INT32} indices. The
     * result is converted to {@link DataType#INT64}, and the INT32 intermediate is closed
     * immediately rather than left attached to the manager.
     */
    @Override
    public NDArray argSort(int axis, boolean ascending) {
        MxOpParams params = new MxOpParams();
        params.addParam("axis", axis);
        params.addParam("is_ascend", ascending);
        params.setDataType(DataType.INT32);
        return toInt64(manager.invoke("argsort", this, params));
    }

    /**
     * {@inheritDoc}
     *
     * <p>Empty and scalar arrays are not passed to the native {@code sort} operator. Instead,
     * {@code axis} is checked against the array's dimension and a duplicate is returned.
     *
     * @throws IllegalArgumentException if the array is empty or scalar and {@code axis} is
     *     greater than or equal to its number of dimensions
     */
    @Override
    public NDArray sort(int axis) {
        if (isEmpty() || isScalar()) {
            // This path never reaches the native operator, so the axis is validated here.
            // NOTE(review): negative axes are not range-checked on this path — confirm intent.
            long dim = getShape().dimension();
            if (axis >= dim) {
                throw new IllegalArgumentException(
                        "axis " + axis + " is out of bounds for array of dimension " + dim);
            }
            return duplicate();
        }
        MxOpParams params = new MxOpParams();
        params.addParam("axis", axis);
        return manager.invoke("sort", this, params);
    }

    /**
     * {@inheritDoc}
     *
     * <p>Empty and scalar arrays are returned as a duplicate without calling the native {@code
     * sort} operator.
     */
    @Override
    public NDArray sort() {
        if (isEmpty() || isScalar()) {
            return duplicate();
        }
        return manager.invoke("sort", this, null);
    }

    /** {@inheritDoc} The parent's result is converted to {@link DataType#INT64}. */
    @Override
    public NDArray argMax() {
        return toInt64(super.argMax());
    }

    /** {@inheritDoc} The parent's result is converted to {@link DataType#INT64}. */
    @Override
    public NDArray argMax(int axis) {
        return toInt64(super.argMax(axis));
    }

    /** {@inheritDoc} The parent's result is converted to {@link DataType#INT64}. */
    @Override
    public NDArray argMin() {
        return toInt64(super.argMin());
    }

    /** {@inheritDoc} The parent's result is converted to {@link DataType#INT64}. */
    @Override
    public NDArray argMin(int axis) {
        return toInt64(super.argMin(axis));
    }

    /**
     * {@inheritDoc}
     *
     * <p>A scalar is reshaped directly to shape {@code (1)} without consulting {@code axis}.
     * Other arrays defer to the parent implementation.
     */
    @Override
    public NDArray expandDims(int axis) {
        if (isScalar()) {
            return reshape(new long[] {1L});
        }
        return super.expandDims(axis);
    }

    /**
     * Not available in this implementation.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public NDArray matMul(NDArray other) {
        throw new UnsupportedOperationException("matMul is not supported in MXNet 1.6.0");
    }

    /**
     * {@inheritDoc}
     *
     * <p>Computed as an element-wise {@code this != this} via {@code _npi_not_equal}. NaN is the
     * only floating-point value that is not equal to itself.
     */
    @Override
    public NDArray isNaN() {
        return manager.invoke("_npi_not_equal", new NDArray[] {this, this}, null);
    }

    /** {@inheritDoc} Implemented with the {@code _np_broadcast_to} operator. */
    @Override
    public NDArray broadcast(Shape shape) {
        MxOpParams params = new MxOpParams();
        params.setShape(shape);
        return manager.invoke("_np_broadcast_to", this, params);
    }

    /**
     * Converts {@code indices} to a new {@link DataType#INT64} array and closes {@code indices}.
     *
     * <p>{@code copy = true} is passed so the returned array is a separate array from the one
     * being closed. {@code indices} is also closed if the conversion throws.
     *
     * @param indices an intermediate array owned by the caller; closed by this method
     * @return the INT64 copy of {@code indices}
     */
    private static NDArray toInt64(NDArray indices) {
        try (NDArray source = indices) {
            return source.toType(DataType.INT64, true);
        }
    }
}

