/*
 * Decompiled with CFR 0.152.
 */
package smile.stat.distribution;

import smile.math.MathEx;
import smile.stat.distribution.DiscreteDistribution;

/**
 * The hypergeometric distribution: the number of successes {@code k} obtained in
 * {@code n} draws, without replacement, from a population of size {@code N} that
 * contains exactly {@code m} successes. The probability mass function is
 * <pre>
 *     p(k) = C(m, k) * C(N - m, n - k) / C(N, n)
 * </pre>
 * on the support {@code max(0, m + n - N) <= k <= min(m, n)} and zero elsewhere.
 */
public class HyperGeometricDistribution extends DiscreteDistribution {
    private static final long serialVersionUID = 2L;
    /** The population size. */
    public final int N;
    /** The number of successes in the population. */
    public final int m;
    /** The number of draws. */
    public final int n;
    /**
     * Sampler created lazily by {@link #rand()}. The sampler classes below do not
     * implement Serializable, so this cache is transient; rand() re-creates it
     * whenever it is null (e.g. after deserialization).
     */
    private transient RandomNumberGenerator rng;

    /**
     * Constructor.
     *
     * @param N the population size.
     * @param m the number of successes in the population.
     * @param n the number of draws.
     * @throws IllegalArgumentException if {@code N < 0}, or {@code m} or {@code n}
     *         is outside {@code [0, N]}.
     */
    public HyperGeometricDistribution(int N, int m, int n) {
        if (N < 0) {
            throw new IllegalArgumentException("Invalid N: " + N);
        }
        if (m < 0 || m > N) {
            throw new IllegalArgumentException("Invalid m: " + m);
        }
        if (n < 0 || n > N) {
            throw new IllegalArgumentException("Invalid n: " + n);
        }
        this.N = N;
        this.m = m;
        this.n = n;
    }

    /** Smallest value with non-zero probability. */
    private int supportMin() {
        return Math.max(0, m + n - N);
    }

    /** Largest value with non-zero probability. */
    private int supportMax() {
        return Math.min(m, n);
    }

    /** Returns 3, one for each of the parameters N, m and n. */
    @Override
    public int length() {
        return 3;
    }

    /**
     * Returns the mean {@code n * m / N}. An empty population ({@code N == 0})
     * can only yield 0, so its mean is 0 rather than 0/0.
     */
    @Override
    public double mean() {
        if (N == 0) {
            return 0.0;
        }
        return (double) m * n / N;
    }

    /**
     * Returns the variance {@code n (m/N) (1 - m/N) (N - n) / (N - 1)}.
     * For {@code N <= 1} the outcome is deterministic, so the variance is 0.
     */
    @Override
    public double variance() {
        if (N <= 1) {
            return 0.0;
        }
        double r = (double) m / N;
        // Multiply in double: n * (N - n) overflows int for large populations.
        return (double) n * (N - n) * r * (1.0 - r) / (N - 1);
    }

    /** Not supported for this distribution. */
    @Override
    public double entropy() {
        throw new UnsupportedOperationException("Hypergeometric distribution does not support entropy()");
    }

    @Override
    public String toString() {
        return String.format("Hypergeometric Distribution(%d, %d, %d)", this.N, this.m, this.n);
    }

    /** Returns {@code P(X = k)}; 0 outside the support. */
    @Override
    public double p(int k) {
        if (k < supportMin() || k > supportMax()) {
            return 0.0;
        }
        return Math.exp(logp(k));
    }

    /** Returns {@code log P(X = k)}; negative infinity outside the support. */
    @Override
    public double logp(int k) {
        if (k < supportMin() || k > supportMax()) {
            return Double.NEGATIVE_INFINITY;
        }
        return MathEx.lchoose(m, k) + MathEx.lchoose(N - m, n - k) - MathEx.lchoose(N, n);
    }

    /**
     * Returns {@code P(X <= k)}, computed by summing the pmf over the support
     * up to {@code k}.
     */
    @Override
    public double cdf(double k) {
        int lower = supportMin();
        if (k < lower) {
            return 0.0;
        }
        if (k >= supportMax()) {
            return 1.0;
        }
        double sum = 0.0;
        for (int i = lower; i <= k; i++) {
            sum += p(i);
        }
        return sum;
    }

    /**
     * Returns the smallest {@code k} such that {@code P(X <= k) >= p}.
     * The quantile is first bracketed by a doubling search from the guess
     * {@code n * p}. The bracket is then handed to the inherited
     * {@code quantile(double, int, int)}, which is not visible here.
     *
     * @throws IllegalArgumentException if {@code p} is NaN or outside {@code [0, 1]}.
     */
    @Override
    public double quantile(double p) {
        // Written so that NaN is rejected as well.
        if (!(p >= 0.0 && p <= 1.0)) {
            throw new IllegalArgumentException("Invalid p: " + p);
        }
        if (p == 0.0) {
            return supportMin();
        }
        if (p == 1.0) {
            return supportMax();
        }

        int inc = 1;
        int k = Math.max(0, Math.min(n, (int) (n * p)));
        int kl;
        int ku;
        if (p < cdf(k)) {
            // Walk left with doubling steps until cdf(k) <= p (or k hits 0).
            do {
                k = Math.max(k - inc, 0);
                inc *= 2;
            } while (p < cdf(k) && k > 0);
            kl = k;
            ku = k + inc / 2;
        } else {
            // Walk right with doubling steps; cdf(n + 1) == 1 stops the loop.
            do {
                k = Math.min(k + inc, n + 1);
                inc *= 2;
            } while (p > cdf(k));
            ku = k;
            kl = k - inc / 2;
        }
        return quantile(p, kl, ku);
    }

    /**
     * Draws a random variate. Patchwork rejection is used when the reduced
     * problem is large ({@code nn * mm >= 20 N}); otherwise chop-down inversion.
     */
    @Override
    public double rand() {
        if (rng == null) {
            // Same folding of m and n into [0, N/2] as RandomNumberGenerator applies.
            int mm = m > N / 2 ? N - m : m;
            int nn = n > N / 2 ? N - n : n;
            // 20.0 * N avoids the int overflow of 20 * N. N > 0 keeps the empty
            // population away from Patchwork, whose setup would then evaluate
            // log-factorials at negative arguments (k1 = k2 = -1).
            if (N > 0 && (double) nn * mm >= 20.0 * N) {
                rng = new Patchwork(N, m, n);
            } else {
                rng = new Inversion(N, m, n);
            }
        }
        return rng.rand();
    }

    /**
     * Base class of the samplers. The constructor reduces the parameters by
     * symmetry so that {@code m <= N/2}, {@code n <= N/2} and {@code n <= m}.
     * A variate {@code x} of the reduced problem is mapped back as
     * {@code x * fak + addd}.
     */
    static abstract class RandomNumberGenerator {
        protected final int N;
        protected final int m;
        protected final int n;
        /** Sign used to undo the symmetry transformations. */
        protected int fak = 1;
        /** Offset used to undo the symmetry transformations. */
        protected int addd = 0;

        RandomNumberGenerator(int N, int m, int n) {
            if (m > N / 2) {
                // Count failures instead of successes.
                m = N - m;
                fak = -1;
                addd = n;
            }
            if (n > N / 2) {
                // Count the items not drawn instead of the items drawn.
                n = N - n;
                addd += fak * m;
                fak = -fak;
            }
            if (n > m) {
                // The distribution is symmetric in m and n.
                int swap = n;
                n = m;
                m = swap;
            }
            this.N = N;
            this.m = m;
            this.n = n;
        }

        /** Returns a variate of the original (untransformed) distribution. */
        public int rand() {
            if (n == 0) {
                // Only one outcome is possible.
                return addd;
            }
            return random() * fak + addd;
        }

        /** Returns a variate of the reduced problem. */
        protected abstract int random();
    }

    /**
     * Patchwork rejection sampler. The region around the mode is split into
     * immediate-acceptance rectangles between the reflection points k2 and k4.
     * Candidates outside them are squeezed and accept-rejected, and the tails
     * are covered by exponential envelopes. Probability areas p1..p6 are cumulative.
     */
    static class Patchwork extends RandomNumberGenerator {
        // Auxiliary parameter N - m - n, and the region delimiters.
        private final int L;
        private final int k1;
        private final int k2;
        private final int k4;
        private final int k5;
        // Widths of the critical left/right centre regions.
        private final double dl;
        private final double dr;
        // Ratios p(k)/p(k-1) at k = k1, k2, k4 + 1, k5 + 1.
        private final double r1;
        private final double r2;
        private final double r4;
        private final double r5;
        // Reciprocal scales of the left/right exponential tail envelopes.
        private final double ll;
        private final double lr;
        // lnpk at the mode; f(k) = exp(cPm - lnpk(k)) = p(k)/p(mode).
        private final double cPm;
        private final double f1;
        private final double f2;
        private final double f4;
        private final double f5;
        // Cumulative areas of the six regions of the hat function.
        private final double p1;
        private final double p2;
        private final double p3;
        private final double p4;
        private final double p5;
        private final double p6;

        Patchwork(int N, int mm, int nn) {
            super(N, mm, nn);
            double mPlus1 = m + 1;
            double nPlus1 = n + 1;
            L = N - m - n;

            double p = mPlus1 / (N + 2.0);
            double modef = nPlus1 * p;
            // Approximate distance of the reflection points k2, k4 from modef - 1/2.
            double u = Math.sqrt(modef * (1.0 - p) * (1.0 - (n + 2.0) / (N + 3.0)) + 0.25);

            int mode = (int) modef;
            int left = (int) Math.ceil(modef - 0.5 - u);
            k2 = left >= mode ? mode - 1 : left;
            k4 = (int) (modef - 0.5 + u);
            k1 = k2 + k2 - mode + 1;
            k5 = k4 + k4 - mode;

            dl = k2 - k1;
            dr = k5 - k4;

            r1 = (nPlus1 / k1 - 1.0) * (mPlus1 - k1) / (L + k1);
            r2 = (nPlus1 / k2 - 1.0) * (mPlus1 - k2) / (L + k2);
            r4 = (nPlus1 / (k4 + 1) - 1.0) * (m - k4) / (L + k4 + 1);
            r5 = (nPlus1 / (k5 + 1) - 1.0) * (m - k5) / (L + k5 + 1);

            ll = Math.log(r1);
            lr = -Math.log(r5);

            cPm = lnpk(mode, L, m, n);
            f2 = Math.exp(cPm - lnpk(k2, L, m, n));
            f4 = Math.exp(cPm - lnpk(k4, L, m, n));
            f1 = Math.exp(cPm - lnpk(k1, L, m, n));
            f5 = Math.exp(cPm - lnpk(k5, L, m, n));

            p1 = f2 * (dl + 1.0);      // immediate acceptance, left
            p2 = f2 * dl + p1;         // centre, left
            p3 = f4 * (dr + 1.0) + p2; // immediate acceptance, right
            p4 = f4 * dr + p3;         // centre, right
            p5 = f1 / ll + p4;         // exponential tail, left
            p6 = f5 / lr + p5;         // exponential tail, right
        }

        @Override
        protected int random() {
            while (true) {
                // u ~ Uniform(0, p6) selects the region of the hat function.
                double u = MathEx.random() * p6;
                double y;
                int x;
                if (u < p2) {
                    // Centre left.
                    double w = u - p1;
                    if (w < 0.0) {
                        // Immediate acceptance in [k2, mode).
                        return k2 + (int) (u / f2);
                    }
                    y = w / dl;
                    if (y < f1) {
                        // Immediate acceptance in [k1, k2).
                        return k1 + (int) (w / f1);
                    }
                    // Candidate x = k2 - dk, with mirror image v = k2 + dk.
                    int dk = (int) (dl * MathEx.random()) + 1;
                    if (y <= f2 - dk * (f2 - f2 / r2)) {
                        return k2 - dk;
                    }
                    w = f2 + f2 - y;
                    if (w < 1.0) {
                        int v = k2 + dk;
                        if (w <= f2 + dk * (1.0 - f2) / (dl + 1.0)) {
                            return v;
                        }
                        if (Math.log(w) <= cPm - lnpk(v, L, m, n)) {
                            return v;
                        }
                    }
                    x = k2 - dk;
                } else if (u < p4) {
                    // Centre right.
                    double w = u - p3;
                    if (w < 0.0) {
                        // Immediate acceptance in [mode, k4].
                        return k4 - (int) ((u - p2) / f4);
                    }
                    y = w / dr;
                    if (y < f5) {
                        // Immediate acceptance in (k4, k5].
                        return k5 - (int) (w / f5);
                    }
                    // Candidate x = k4 + dk, with mirror image v = k4 - dk.
                    int dk = (int) (dr * MathEx.random()) + 1;
                    if (y <= f4 - dk * (f4 - f4 * r4)) {
                        return k4 + dk;
                    }
                    w = f4 + f4 - y;
                    if (w < 1.0) {
                        int v = k4 - dk;
                        if (w <= f4 + dk * (1.0 - f4) / dr) {
                            return v;
                        }
                        if (Math.log(w) <= cPm - lnpk(v, L, m, n)) {
                            return v;
                        }
                    }
                    x = k4 + dk;
                } else {
                    y = MathEx.random();
                    if (u < p5) {
                        // Left exponential tail: 0 <= x <= k1 - 1.
                        int dk = (int) (1.0 - Math.log(y) / ll);
                        x = k1 - dk;
                        if (x < 0) {
                            continue;
                        }
                        y *= (u - p4) * ll;
                        if (y <= f1 - dk * (f1 - f1 / r1)) {
                            return x;
                        }
                    } else {
                        // Right exponential tail: k5 + 1 <= x <= n.
                        int dk = (int) (1.0 - Math.log(y) / lr);
                        x = k5 + dk;
                        if (x > n) {
                            continue;
                        }
                        y *= (u - p5) * lr;
                        if (y <= f5 - dk * (f5 - f5 * r5)) {
                            return x;
                        }
                    }
                }
                // Final acceptance-rejection test: y <= f(x).
                if (Math.log(y) <= cPm - lnpk(x, L, m, n)) {
                    return x;
                }
            }
        }

        /** Sum of log-factorials of k, m - k, n - k and L + k. */
        private double lnpk(int k, int L, int m, int n) {
            return MathEx.lfactorial(k) + MathEx.lfactorial(m - k) + MathEx.lfactorial(n - k) + MathEx.lfactorial(L + k);
        }
    }

    /**
     * Inversion sampler using a chop-down search that alternates downward and
     * upward from the mode. Past the alternating phase it continues upward to a
     * safety bound, and restarts if that bound is reached.
     */
    static class Inversion extends RandomNumberGenerator {
        /** Mode of the reduced problem (possibly decremented, see constructor). */
        private final int mode;
        /** Start of the downward search (normally mode + 1). */
        private final int mp;
        /** Upper safety bound of the upward search, at most n. */
        private final int bound;
        /** Probability mass at {@code mode}. */
        private final double fm;

        Inversion(int N, int mm, int nn) {
            super(N, mm, nn);
            int L = N - m - n;
            double mPlus1 = m + 1;
            double nPlus1 = n + 1;
            double p = mPlus1 / (N + 2.0);
            double modef = nPlus1 * p;

            int md = (int) modef;
            if (md == modef && p == 0.5) {
                mp = md;
                md--;
            } else {
                mp = md + 1;
            }
            mode = md;

            // log C(N-m, n-mode) + log C(m, mode) - log C(N, n)
            fm = Math.exp(MathEx.lfactorial(N - m) - MathEx.lfactorial(L + mode) - MathEx.lfactorial(n - mode)
                    + MathEx.lfactorial(m) - MathEx.lfactorial(m - mode) - MathEx.lfactorial(mode)
                    - MathEx.lfactorial(N) + MathEx.lfactorial(N - n) + MathEx.lfactorial(n));
            bound = Math.min(n, (int) (modef + 11.0 * Math.sqrt(modef * (1.0 - p) * (1.0 - (double) n / N) + 1.0)));
        }

        @Override
        protected int random() {
            double lf = N - m - n;
            double mPlus1 = m + 1;
            double nPlus1 = n + 1;
            while (true) {
                double u = MathEx.random() - fm;
                if (u <= 0.0) {
                    return mode;
                }
                // c and d track the (scaled) masses of the current down/up candidates.
                // Instead of dividing them, u is multiplied by the common divisor.
                double c = fm;
                double d = fm;
                double k1 = mp - 1;
                double k2 = mode + 1;
                for (int i = 1; i <= mode; i++, k1 -= 1.0, k2 += 1.0) {
                    // Downward step: candidate k1 - 1.
                    double divisor = (nPlus1 - k1) * (mPlus1 - k1);
                    u *= divisor;
                    d *= divisor;
                    c *= k1 * (lf + k1);
                    u -= c;
                    if (u <= 0.0) {
                        return mp - i - 1;
                    }
                    // Upward step: candidate k2.
                    divisor = k2 * (lf + k2);
                    u *= divisor;
                    c *= divisor;
                    d *= (nPlus1 - k2) * (mPlus1 - k2);
                    u -= d;
                    if (u <= 0.0) {
                        return mode + i;
                    }
                    // Rescale to avoid overflow of the unnormalized quantities.
                    if (u > 1.0E100) {
                        u *= 1.0E-100;
                        c *= 1.0E-100;
                        d *= 1.0E-100;
                    }
                }
                // Only the upper tail remains: search upward to the safety bound.
                k2 = mp + mode;
                for (int i = mp + mode; i <= bound; i++, k2 += 1.0) {
                    double divisor = k2 * (lf + k2);
                    u *= divisor;
                    d *= (nPlus1 - k2) * (mPlus1 - k2);
                    u -= d;
                    if (u <= 0.0) {
                        return i;
                    }
                    if (u > 1.0E100) {
                        u *= 1.0E-100;
                        d *= 1.0E-100;
                    }
                }
                // Bound reached without acceptance: draw a fresh uniform.
            }
        }
    }
}

