/*
 * Decompiled with CFR 0.152.
 */
package smile.tensor;

import java.io.Serializable;
import java.lang.foreign.MemorySegment;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.linalg.Diag;
import smile.linalg.Side;
import smile.linalg.Transpose;
import smile.linalg.UPLO;
import smile.linalg.lapack.clapack_h;
import smile.tensor.Cholesky;
import smile.tensor.DenseMatrix;
import smile.tensor.Vector;

/**
 * The QR decomposition {@code A = Q * R} of an m-by-n matrix A, kept in the
 * compact layout consumed by LAPACK ORGQR/ORMQR: the upper triangle of
 * {@code qr} (diagonal included) holds R, and the entries below the diagonal,
 * together with {@code tau}, encode the Householder reflectors that form Q.
 * The factors are presumably produced by LAPACK GEQRF — confirm at the
 * construction site.
 *
 * <p>ORGQR (used by {@link #Q()}) requires {@code m >= n}.
 *
 * @param qr  the compact QR factors.
 * @param tau the scalar factors of the elementary reflectors.
 */
public record QR(DenseMatrix qr, Vector tau) implements Serializable
{
    private static final Logger logger = LoggerFactory.getLogger(QR.class);

    /**
     * Returns the Cholesky decomposition of {@code A' * A}.
     *
     * <p>Because Q has orthonormal columns, {@code A' * A = R' * R}, so the
     * lower-triangular factor is {@code L = R'}. Householder QR may leave
     * negative values on the diagonal of R; the corresponding columns of L are
     * negated so that L has a non-negative diagonal, as a Cholesky factor
     * conventionally does. Negating a column of L leaves {@code L * L'} unchanged.
     *
     * @return the Cholesky decomposition of {@code A' * A}.
     */
    public Cholesky toCholesky() {
        int n = qr.n;
        DenseMatrix L = qr.zeros(n, n);
        for (int j = 0; j < n; j++) {
            boolean flip = qr.get(j, j) < 0;
            for (int i = j; i < n; i++) {
                // L[i][j] = R[j][i], with column j negated when R[j][j] < 0.
                L.set(i, j, flip ? -qr.get(j, i) : qr.get(j, i));
            }
        }
        L.withUplo(UPLO.LOWER);
        return new Cholesky(L);
    }

    /**
     * Returns the n-by-n upper-triangular factor R.
     *
     * @return the upper-triangular factor R.
     */
    public DenseMatrix R() {
        int n = qr.n;
        // Start from zeros; only the upper triangle (diagonal included) is copied in.
        DenseMatrix R = qr.zeros(n, n);
        for (int i = 0; i < n; i++) {
            for (int j = i; j < n; j++) {
                R.set(i, j, qr.get(i, j));
            }
        }
        return R;
    }

    /**
     * Returns the m-by-n factor Q with orthonormal columns, generated from the
     * Householder reflectors by LAPACK ORGQR. The compact factors are copied
     * first, so this record is not modified.
     *
     * @return the orthogonal factor Q.
     * @throws ArithmeticException if ORGQR reports an error (e.g. when m &lt; n).
     * @throws UnsupportedOperationException for scalar types other than Float64/Float32.
     */
    public DenseMatrix Q() {
        DenseMatrix Q = qr.copy();
        // ORGQR requires LWORK >= max(1, N).
        Vector work = Q.vector(Math.max(1, qr.n));
        int[] m = {qr.m};
        int[] n = {qr.n};
        int[] k = {Math.min(qr.m, qr.n)};
        int[] lda = {Q.ld};
        int[] lwork = {work.size()};
        int[] info = {0};
        MemorySegment m_ = MemorySegment.ofArray(m);
        MemorySegment n_ = MemorySegment.ofArray(n);
        MemorySegment k_ = MemorySegment.ofArray(k);
        MemorySegment lda_ = MemorySegment.ofArray(lda);
        MemorySegment lwork_ = MemorySegment.ofArray(lwork);
        MemorySegment info_ = MemorySegment.ofArray(info);
        switch (Q.scalarType()) {
            case Float64 -> clapack_h.dorgqr_(m_, n_, k_, Q.memory, lda_, tau.memory, work.memory, lwork_, info_);
            case Float32 -> clapack_h.sorgqr_(m_, n_, k_, Q.memory, lda_, tau.memory, work.memory, lwork_, info_);
            default -> throw new UnsupportedOperationException("Unsupported scalar type: " + Q.scalarType());
        }
        if (info[0] != 0) {
            logger.error("LAPACK ORGQR error code: {}", info[0]);
            throw new ArithmeticException("LAPACK ORGQR error code: " + info[0]);
        }
        return Q;
    }

    /**
     * Solves {@code A * x = b} (in the least-squares sense when m &gt; n).
     *
     * @param b the right-hand side; must have exactly m elements.
     * @return the solution vector of length n.
     * @throws IllegalArgumentException if {@code b.length != m} or LAPACK reports an error.
     */
    public Vector solve(double[] b) {
        checkRowDimension("b", b.length, 1);
        Vector x = qr.vector(qr.m);
        for (int i = 0; i < qr.m; i++) {
            x.set(i, b[i]);
        }
        solve(x);
        return x.slice(0, qr.n);
    }

    /**
     * Solves {@code A * x = b} (in the least-squares sense when m &gt; n).
     *
     * @param b the right-hand side; must have exactly m elements.
     * @return the solution vector of length n.
     * @throws IllegalArgumentException if {@code b.length != m} or LAPACK reports an error.
     */
    public Vector solve(float[] b) {
        checkRowDimension("b", b.length, 1);
        Vector x = qr.vector(qr.m);
        for (int i = 0; i < qr.m; i++) {
            x.set(i, b[i]);
        }
        solve(x);
        return x.slice(0, qr.n);
    }

    /**
     * Solves {@code A * X = B} in place (in the least-squares sense when m &gt; n).
     *
     * <p>First overwrites B with {@code Q' * B} (LAPACK ORMQR), then solves the
     * upper-triangular system {@code R * X = (Q' * B)[0:n, :]} (LAPACK TRTRS).
     * On return, the first n rows of B hold X.
     *
     * @param B the right-hand sides; must share A's scalar type and have m rows.
     * @throws IllegalArgumentException if scalar types or row dimensions differ,
     *         or LAPACK reports an error (TRTRS returns a positive code when a
     *         diagonal entry of R is exactly zero, i.e. R is singular).
     * @throws UnsupportedOperationException for scalar types other than Float64/Float32.
     */
    public void solve(DenseMatrix B) {
        if (qr.scalarType() != B.scalarType()) {
            throw new IllegalArgumentException("Incompatible ScalarType: " + B.scalarType() + " != " + qr.scalarType());
        }
        checkRowDimension("B", B.m, B.n);

        // Workspace query (LWORK = -1): ORMQR writes the optimal size to work[0].
        Vector work = qr.vector(1);
        applyQt(B, work, -1);
        work = qr.vector(Math.max(1, (int) work.get(0)));
        applyQt(B, work, work.size());

        solveR(B);
    }

    /** Throws if a right-hand side with {@code rows} rows does not match A's row count. */
    private void checkRowDimension(String name, int rows, int cols) {
        if (rows != qr.m) {
            throw new IllegalArgumentException(String.format(
                    "Row dimensions do not agree: A is %d x %d, but %s is %d x %d",
                    qr.m, qr.n, name, rows, cols));
        }
    }

    /** Overwrites B with {@code Q' * B} via LAPACK ORMQR; {@code lwork = -1} performs a workspace query. */
    private void applyQt(DenseMatrix B, Vector work, int lwork) {
        byte[] side = {Side.LEFT.lapack()};
        byte[] trans = {Transpose.TRANSPOSE.lapack()};
        int[] m = {B.m};
        int[] n = {B.n};
        int[] k = {Math.min(qr.m, qr.n)};
        int[] lda = {qr.ld};
        int[] ldb = {B.ld};
        int[] lworkArg = {lwork};
        int[] info = {0};
        MemorySegment side_ = MemorySegment.ofArray(side);
        MemorySegment trans_ = MemorySegment.ofArray(trans);
        MemorySegment m_ = MemorySegment.ofArray(m);
        MemorySegment n_ = MemorySegment.ofArray(n);
        MemorySegment k_ = MemorySegment.ofArray(k);
        MemorySegment lda_ = MemorySegment.ofArray(lda);
        MemorySegment ldb_ = MemorySegment.ofArray(ldb);
        MemorySegment lwork_ = MemorySegment.ofArray(lworkArg);
        MemorySegment info_ = MemorySegment.ofArray(info);
        switch (qr.scalarType()) {
            case Float64 -> clapack_h.dormqr_(side_, trans_, m_, n_, k_, qr.memory, lda_, tau.memory, B.memory, ldb_, work.memory, lwork_, info_);
            case Float32 -> clapack_h.sormqr_(side_, trans_, m_, n_, k_, qr.memory, lda_, tau.memory, B.memory, ldb_, work.memory, lwork_, info_);
            default -> throw new UnsupportedOperationException("Unsupported scalar type: " + qr.scalarType());
        }
        if (info[0] != 0) {
            logger.error("LAPACK ORMQR error code: {}", info[0]);
            throw new IllegalArgumentException("LAPACK ORMQR error code: " + info[0]);
        }
    }

    /** Overwrites the first n rows of B with the solution of {@code R * X = B[0:n, :]} via LAPACK TRTRS. */
    private void solveR(DenseMatrix B) {
        byte[] uplo = {UPLO.UPPER.lapack()};
        byte[] trans = {Transpose.NO_TRANSPOSE.lapack()};
        byte[] diag = {Diag.NON_UNIT.lapack()};
        int[] n = {qr.n};
        int[] nrhs = {B.n};
        int[] lda = {qr.ld};
        int[] ldb = {B.ld};
        int[] info = {0};
        MemorySegment uplo_ = MemorySegment.ofArray(uplo);
        MemorySegment trans_ = MemorySegment.ofArray(trans);
        MemorySegment diag_ = MemorySegment.ofArray(diag);
        MemorySegment n_ = MemorySegment.ofArray(n);
        MemorySegment nrhs_ = MemorySegment.ofArray(nrhs);
        MemorySegment lda_ = MemorySegment.ofArray(lda);
        MemorySegment ldb_ = MemorySegment.ofArray(ldb);
        MemorySegment info_ = MemorySegment.ofArray(info);
        switch (qr.scalarType()) {
            case Float64 -> clapack_h.dtrtrs_(uplo_, trans_, diag_, n_, nrhs_, qr.memory, lda_, B.memory, ldb_, info_);
            case Float32 -> clapack_h.strtrs_(uplo_, trans_, diag_, n_, nrhs_, qr.memory, lda_, B.memory, ldb_, info_);
            default -> throw new UnsupportedOperationException("Unsupported scalar type: " + qr.scalarType());
        }
        if (info[0] != 0) {
            logger.error("LAPACK TRTRS error code: {}", info[0]);
            throw new IllegalArgumentException("LAPACK TRTRS error code: " + info[0]);
        }
    }
}

