/*
 * Decompiled with CFR 0.152.
 */
package smile.tensor;

import java.io.Serializable;
import java.lang.foreign.MemorySegment;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.linalg.Order;
import smile.linalg.Transpose;
import smile.linalg.UPLO;
import smile.linalg.blas.cblas_h;
import smile.linalg.lapack.clapack_h;
import smile.math.MathEx;
import smile.tensor.DenseMatrix;
import smile.tensor.Matrix;
import smile.tensor.ScalarType;
import smile.tensor.SymmMatrix32;
import smile.tensor.SymmMatrix64;
import smile.tensor.Vector;

/**
 * A symmetric matrix in packed storage.
 *
 * <p>Only the triangle selected by {@link #uplo()} is stored, as
 * {@code n * (n + 1) / 2} contiguous elements in column-major order
 * (see {@link #order()}). The element type is determined by the concrete
 * subclass: {@link #zeros(ScalarType, UPLO, int)} creates a
 * {@code SymmMatrix64} for {@code Float64} and a {@code SymmMatrix32}
 * for {@code Float32}.
 */
public abstract class SymmMatrix implements Matrix, Serializable {
    private static final Logger logger = LoggerFactory.getLogger(SymmMatrix.class);

    /**
     * The packed matrix elements. This field is transient; presumably the
     * subclasses rebuild it from their own serialized array (TODO confirm).
     */
    transient MemorySegment memory;
    /** The matrix dimension (number of rows and of columns). */
    final int n;
    /** The triangle of the matrix that is stored. */
    final UPLO uplo;

    /**
     * Creates an uninitialized instance (null memory, zero size, null uplo).
     * NOTE(review): presumably reserved for subclasses / serialization support — confirm.
     */
    SymmMatrix() {
        this.memory = null;
        this.n = 0;
        this.uplo = null;
    }

    /**
     * Constructor.
     *
     * @param memory the packed matrix elements.
     * @param uplo the stored triangle.
     * @param n the matrix dimension.
     * @throws IllegalArgumentException if {@code n <= 0}.
     */
    SymmMatrix(MemorySegment memory, UPLO uplo, int n) {
        if (n <= 0) {
            throw new IllegalArgumentException(String.format("Invalid matrix size: %d x %d", n, n));
        }
        this.memory = memory;
        this.uplo = uplo;
        this.n = n;
    }

    /**
     * Returns a zero-filled symmetric matrix.
     *
     * @param scalarType the element type; {@code Float64} or {@code Float32}.
     * @param uplo the triangle to store.
     * @param n the matrix dimension.
     * @return a new zero matrix.
     * @throws IllegalArgumentException if {@code uplo} is null, {@code n <= 0},
     *         or the packed length {@code n * (n + 1) / 2} does not fit in an int.
     * @throws UnsupportedOperationException for any other scalar type.
     */
    public static SymmMatrix zeros(ScalarType scalarType, UPLO uplo, int n) {
        if (uplo == null) {
            throw new IllegalArgumentException("UPLO is null");
        }
        return switch (scalarType) {
            case Float64 -> new SymmMatrix64(uplo, n, new double[packedLength(n)]);
            case Float32 -> new SymmMatrix32(uplo, n, new float[packedLength(n)]);
            default -> throw new UnsupportedOperationException("Unsupported ScalarType: " + scalarType);
        };
    }

    /**
     * Returns the packed-storage length {@code n * (n + 1) / 2}.
     * The product is formed in {@code long}; in {@code int} it overflows for
     * n ≥ 46341 even though the final length still fits.
     */
    private static int packedLength(int n) {
        if (n <= 0) {
            throw new IllegalArgumentException(String.format("Invalid matrix size: %d x %d", n, n));
        }
        long length = (long) n * (n + 1L) / 2;
        if (length > Integer.MAX_VALUE) {
            throw new IllegalArgumentException(String.format("Matrix too large for packed storage: %d x %d", n, n));
        }
        return (int) length;
    }

    /**
     * Copies the stored triangle of a symmetric dense matrix into packed storage.
     * The result uses the same scalar type and triangle as {@code A}.
     *
     * @param A a symmetric dense matrix.
     * @return the packed symmetric matrix.
     * @throws IllegalArgumentException if {@code A} is not symmetric.
     */
    public static SymmMatrix of(DenseMatrix A) {
        if (!A.isSymmetric()) {
            throw new IllegalArgumentException("The input matrix is not symmetric");
        }
        int n = A.ncol();
        UPLO uplo = A.uplo();
        SymmMatrix matrix = SymmMatrix.zeros(A.scalarType(), uplo, n);
        switch (uplo) {
            case LOWER -> {
                for (int i = 0; i < n; ++i) {
                    for (int j = 0; j <= i; ++j) {
                        matrix.set(i, j, A.get(i, j));
                    }
                }
            }
            case UPPER -> {
                for (int i = 0; i < n; ++i) {
                    for (int j = i; j < n; ++j) {
                        matrix.set(i, j, A.get(i, j));
                    }
                }
            }
        }
        return matrix;
    }

    /**
     * Creates a double-precision symmetric matrix from a 2-dimensional array.
     * Only the entries in the {@code uplo} triangle of {@code AP} are read.
     *
     * @param uplo the triangle to read and store.
     * @param AP the matrix; {@code AP.length} is taken as the dimension.
     * @return the packed symmetric matrix.
     */
    public static SymmMatrix of(UPLO uplo, double[][] AP) {
        int n = AP.length;
        SymmMatrix matrix = SymmMatrix.zeros(ScalarType.Float64, uplo, n);
        switch (uplo) {
            case LOWER -> {
                for (int i = 0; i < n; ++i) {
                    for (int j = 0; j <= i; ++j) {
                        matrix.set(i, j, AP[i][j]);
                    }
                }
            }
            case UPPER -> {
                for (int i = 0; i < n; ++i) {
                    for (int j = i; j < n; ++j) {
                        matrix.set(i, j, AP[i][j]);
                    }
                }
            }
        }
        return matrix;
    }

    /**
     * Creates a single-precision symmetric matrix from a 2-dimensional array.
     * Only the entries in the {@code uplo} triangle of {@code AP} are read.
     *
     * @param uplo the triangle to read and store.
     * @param AP the matrix; {@code AP.length} is taken as the dimension.
     * @return the packed symmetric matrix.
     */
    public static SymmMatrix of(UPLO uplo, float[][] AP) {
        int n = AP.length;
        SymmMatrix matrix = SymmMatrix.zeros(ScalarType.Float32, uplo, n);
        switch (uplo) {
            case LOWER -> {
                for (int i = 0; i < n; ++i) {
                    for (int j = 0; j <= i; ++j) {
                        matrix.set(i, j, AP[i][j]);
                    }
                }
            }
            case UPPER -> {
                for (int i = 0; i < n; ++i) {
                    for (int j = i; j < n; ++j) {
                        matrix.set(i, j, AP[i][j]);
                    }
                }
            }
        }
        return matrix;
    }

    @Override
    public int nrow() {
        return this.n;
    }

    @Override
    public int ncol() {
        return this.n;
    }

    /**
     * Scales every stored element in place by {@code alpha}.
     *
     * @param alpha the scaling factor (cast to float for single precision).
     * @return this matrix.
     */
    @Override
    public SymmMatrix scale(double alpha) {
        switch (this.scalarType()) {
            case Float64 -> cblas_h.cblas_dscal((int) this.length(), alpha, this.memory, 1);
            case Float32 -> cblas_h.cblas_sscal((int) this.length(), (float) alpha, this.memory, 1);
            default -> throw new UnsupportedOperationException("Unsupported scalar type: " + this.scalarType());
        }
        return this;
    }

    @Override
    public abstract SymmMatrix copy();

    /** A symmetric matrix is its own transpose, so this returns {@code this} (not a copy). */
    @Override
    public SymmMatrix transpose() {
        return this;
    }

    /**
     * Returns the memory layout of the packed storage.
     *
     * @return always {@link Order#COL_MAJOR}.
     */
    public Order order() {
        return Order.COL_MAJOR;
    }

    /**
     * Returns the stored triangle.
     *
     * @return the stored triangle.
     */
    public UPLO uplo() {
        return this.uplo;
    }

    /**
     * Compares element-wise with an absolute tolerance of {@code 10 * FLOAT_EPSILON},
     * which is used for both precisions. Only the upper triangle (i ≤ j) is
     * compared, which is enough for symmetric matrices. The scalar types are not
     * compared. As with the original comparison, a NaN difference is not treated
     * as a mismatch.
     */
    @Override
    public boolean equals(Object o) {
        if (!(o instanceof SymmMatrix b) || this.nrow() != b.nrow()) {
            return false;
        }
        double tol = 10.0f * MathEx.FLOAT_EPSILON;
        for (int j = 0; j < this.n; ++j) {
            for (int i = 0; i <= j; ++i) {
                if (Math.abs(this.get(i, j) - b.get(i, j)) > tol) {
                    return false;
                }
            }
        }
        return true;
    }

    /**
     * Returns a hash code consistent with {@link #equals(Object)}. Because
     * equality is tolerance-based, element values cannot contribute to the hash;
     * only the dimension does.
     */
    @Override
    public int hashCode() {
        return Integer.hashCode(this.n);
    }

    /**
     * Computes {@code y = alpha * A * x + beta * y} with BLAS {@code ?spmv}.
     * The {@code trans} argument is ignored because a symmetric matrix equals its
     * transpose.
     *
     * @throws IllegalArgumentException if {@code x} or {@code y} has a different
     *         scalar type than this matrix.
     */
    @Override
    public void mv(Transpose trans, double alpha, Vector x, double beta, Vector y) {
        if (this.scalarType() != x.scalarType()) {
            throw new IllegalArgumentException("Incompatible ScalarType: " + this.scalarType() + " != " + x.scalarType());
        }
        if (this.scalarType() != y.scalarType()) {
            throw new IllegalArgumentException("Incompatible ScalarType: " + this.scalarType() + " != " + y.scalarType());
        }
        switch (this.scalarType()) {
            case Float64 -> cblas_h.cblas_dspmv(this.order().blas(), this.uplo.blas(), this.n, alpha, this.memory, x.memory, 1, beta, y.memory, 1);
            case Float32 -> cblas_h.cblas_sspmv(this.order().blas(), this.uplo.blas(), this.n, (float) alpha, this.memory, x.memory, 1, (float) beta, y.memory, 1);
            default -> throw new UnsupportedOperationException("Unsupported scalar type: " + this.scalarType());
        }
    }

    /**
     * Solves {@code A * x = b}.
     *
     * @param b the right-hand side; must have exactly {@code n} elements.
     * @return the solution vector.
     * @throws IllegalArgumentException if {@code b.length != n}.
     */
    public Vector solve(double[] b) {
        checkRhsLength(b.length);
        Vector x = this.vector(this.n);
        for (int i = 0; i < this.n; ++i) {
            x.set(i, b[i]);
        }
        this.solve(x);
        return x;
    }

    /**
     * Solves {@code A * x = b}.
     *
     * @param b the right-hand side; must have exactly {@code n} elements.
     * @return the solution vector.
     * @throws IllegalArgumentException if {@code b.length != n}.
     */
    public Vector solve(float[] b) {
        checkRhsLength(b.length);
        Vector x = this.vector(this.n);
        for (int i = 0; i < this.n; ++i) {
            x.set(i, b[i]);
        }
        this.solve(x);
        return x;
    }

    /** Rejects a right-hand-side array that does not match the matrix dimension. */
    private void checkRhsLength(int length) {
        if (length != this.n) {
            throw new IllegalArgumentException(String.format("Row dimensions do not agree: A is %d x %d, but b has %d elements", this.n, this.n, length));
        }
    }

    /**
     * Solves {@code A * X = B} in place: on return, {@code B} holds {@code X}.
     * A copy of this matrix is factored with LAPACK {@code ?sptrf}, so this matrix
     * is left unchanged. The factors are then applied with {@code ?sptrs}.
     *
     * @param B the right-hand sides, overwritten with the solution.
     * @throws IllegalArgumentException if {@code B} has a different number of rows
     *         or a different scalar type than this matrix.
     * @throws ArithmeticException if LAPACK reports an illegal argument.
     * @throws RuntimeException if the factorization finds the matrix singular.
     */
    public void solve(DenseMatrix B) {
        if (B.m != this.n) {
            throw new IllegalArgumentException(String.format("Row dimensions do not agree: A is %d x %d, but B is %d x %d", this.n, this.n, B.m, B.n));
        }
        // Native code reinterprets B's memory by this matrix's precision, so a
        // mismatch would read and write out of bounds.
        if (this.scalarType() != B.scalarType()) {
            throw new IllegalArgumentException("Incompatible ScalarType: " + this.scalarType() + " != " + B.scalarType());
        }
        // ?sptrf overwrites its input with the factors, so factor a copy.
        SymmMatrix lu = this.copy();
        // LAPACK (Fortran interface) takes every scalar argument by reference.
        byte[] uplo = {this.uplo.lapack()};
        int[] n = {lu.n};
        int[] ipiv = new int[lu.n];
        int[] info = {0};
        MemorySegment n_ = MemorySegment.ofArray(n);
        MemorySegment uplo_ = MemorySegment.ofArray(uplo);
        MemorySegment ipiv_ = MemorySegment.ofArray(ipiv);
        MemorySegment info_ = MemorySegment.ofArray(info);
        switch (this.scalarType()) {
            case Float64 -> clapack_h.dsptrf_(uplo_, n_, lu.memory, ipiv_, info_);
            case Float32 -> clapack_h.ssptrf_(uplo_, n_, lu.memory, ipiv_, info_);
            default -> throw new UnsupportedOperationException("Unsupported scalar type: " + this.scalarType());
        }
        if (info[0] < 0) {
            logger.error("LAPACK SPTRF error code: {}", info[0]);
            throw new ArithmeticException("LAPACK SPTRF error code: " + info[0]);
        }
        if (info[0] > 0) {
            throw new RuntimeException("The matrix is singular.");
        }
        int[] nrhs = {B.n};
        int[] ldb = {B.ld};
        MemorySegment nrhs_ = MemorySegment.ofArray(nrhs);
        MemorySegment ldb_ = MemorySegment.ofArray(ldb);
        // The pivot indices from ?sptrf are reused by ?sptrs.
        switch (this.scalarType()) {
            case Float64 -> clapack_h.dsptrs_(uplo_, n_, nrhs_, lu.memory, ipiv_, B.memory, ldb_, info_);
            case Float32 -> clapack_h.ssptrs_(uplo_, n_, nrhs_, lu.memory, ipiv_, B.memory, ldb_, info_);
            default -> throw new UnsupportedOperationException("Unsupported scalar type: " + this.scalarType());
        }
        if (info[0] != 0) {
            logger.error("LAPACK SPTRS error code: {}", info[0]);
            throw new ArithmeticException("LAPACK SPTRS error code: " + info[0]);
        }
    }
}

