/*
 * Decompiled with CFR 0.152.
 */
package smile.stat.distribution;

import java.util.Arrays;
import java.util.stream.Collectors;
import smile.math.MathEx;
import smile.stat.distribution.DiscreteDistribution;

/**
 * A finite mixture of discrete distributions. The probability mass function is
 * the prior-weighted sum of the component mass functions,
 * {@code p(x) = sum_i priori_i * p_i(x)}.
 */
public class DiscreteMixture
extends DiscreteDistribution {
    private static final long serialVersionUID = 2L;
    /** The mixture components; their priori must sum to 1 (within 0.001). */
    public final Component[] components;

    /**
     * Constructs a mixture from its components.
     *
     * @param components the mixture components.
     * @throws IllegalStateException if no component is given.
     * @throws IllegalArgumentException if the priori do not sum to 1 within 0.001.
     */
    public DiscreteMixture(Component ... components) {
        if (components.length == 0) {
            throw new IllegalStateException("Empty mixture!");
        }
        this.components = components;
        double sum = 0.0;
        for (Component component : components) {
            sum += component.priori;
        }
        if (Math.abs(sum - 1.0) > 0.001) {
            throw new IllegalArgumentException("The sum of priori is not equal to 1.");
        }
    }

    /** Returns {@code priori_i * p_i(x)} for every component i (unnormalized). */
    private double[] joint(int x) {
        double[] prob = new double[this.components.length];
        for (int i = 0; i < prob.length; i++) {
            Component c = this.components[i];
            prob[i] = c.priori * c.distribution.p(x);
        }
        return prob;
    }

    /**
     * Returns the posterior probability of each component given an observation.
     * If every component assigns zero mass to {@code x}, the entries are NaN.
     *
     * @param x the observation.
     * @return the posterior probabilities, one per component.
     */
    public double[] posteriori(int x) {
        double[] prob = joint(x);
        double p = MathEx.sum(prob);
        for (int i = 0; i < prob.length; i++) {
            prob[i] /= p;
        }
        return prob;
    }

    /**
     * Returns the index of the component with the maximum a posteriori probability.
     *
     * @param x the observation.
     * @return the index of the most probable component.
     */
    public int map(int x) {
        // Normalization does not change the arg max, so the joint values suffice.
        return MathEx.whichMax(joint(x));
    }

    @Override
    public double mean() {
        double mu = 0.0;
        for (Component c : this.components) {
            mu += c.priori * c.distribution.mean();
        }
        return mu;
    }

    /**
     * Returns the mixture variance by the law of total variance:
     * {@code sum_i priori_i * (var_i + (mean_i - mean)^2)}.
     */
    @Override
    public double variance() {
        double mu = this.mean();
        double variance = 0.0;
        for (Component c : this.components) {
            // Within-component spread plus spread of the component means around mu.
            double d = c.distribution.mean() - mu;
            variance += c.priori * (c.distribution.variance() + d * d);
        }
        return variance;
    }

    @Override
    public double entropy() {
        throw new UnsupportedOperationException("Mixture does not support entropy()");
    }

    @Override
    public double p(int x) {
        double p = 0.0;
        for (Component c : this.components) {
            p += c.priori * c.distribution.p(x);
        }
        return p;
    }

    @Override
    public double logp(int x) {
        return Math.log(this.p(x));
    }

    @Override
    public double cdf(double x) {
        double p = 0.0;
        for (Component c : this.components) {
            p += c.priori * c.distribution.cdf(x);
        }
        return p;
    }

    /**
     * Draws a sample by first choosing a component according to the priori,
     * then sampling from that component.
     */
    @Override
    public double rand() {
        double r = MathEx.random();
        double cumulative = 0.0;
        for (Component c : this.components) {
            cumulative += c.priori;
            if (r <= cumulative) {
                return c.distribution.rand();
            }
        }
        // The priori may sum to slightly less than 1 (floating-point rounding, or the
        // 0.001 tolerance accepted by the constructor); give that residual mass to the
        // last component instead of failing.
        return this.components[this.components.length - 1].distribution.rand();
    }

    /**
     * Returns the p-quantile. Starting at the mean, the search steps away with
     * geometrically growing strides until p is bracketed, then delegates to
     * {@code quantile(p, xl, xu)}.
     *
     * @param p the probability, in [0, 1].
     * @throws IllegalArgumentException if p is outside [0, 1].
     */
    @Override
    public double quantile(double p) {
        if (p < 0.0 || p > 1.0) {
            throw new IllegalArgumentException("Invalid p: " + p);
        }
        int inc = 1;
        int x = (int) this.mean();
        int xl;
        int xu;
        if (p < this.cdf(x)) {
            // Step first, then double, so that x + inc / 2 is exactly the previous probe.
            do {
                x -= inc;
                inc *= 2;
            } while (p < this.cdf(x));
            xl = x;
            xu = x + inc / 2; // previous probe, where p < cdf
        } else {
            do {
                x += inc;
                inc *= 2;
            } while (p > this.cdf(x));
            xu = x;
            xl = x - inc / 2; // previous probe, where p >= cdf
        }
        return this.quantile(p, xl, xu);
    }

    /**
     * Returns the number of free parameters: the k - 1 independent mixing
     * weights plus the parameters of every component.
     */
    @Override
    public int length() {
        int length = this.components.length - 1;
        for (Component component : this.components) {
            length += component.distribution.length();
        }
        return length;
    }

    /** Returns the number of components in the mixture. */
    public int size() {
        return this.components.length;
    }

    /**
     * Returns the BIC score in the form {@code logL - 0.5 * length() * ln(n)}
     * (larger is better). Observations with zero probability are skipped.
     *
     * @param data the observations.
     * @return the BIC score.
     */
    public double bic(double[] data) {
        int n = data.length;
        double logLikelihood = 0.0;
        for (double x : data) {
            double p = this.p(x);
            if (p > 0.0) {
                logLikelihood += Math.log(p);
            }
        }
        return logLikelihood - 0.5 * (double) this.length() * Math.log(n);
    }

    @Override
    public String toString() {
        return Arrays.stream(this.components)
                .map(component -> String.format("%.2f x %s", component.priori, component.distribution))
                .collect(Collectors.joining(" + ", String.format("Mixture(%d)[", this.components.length), "]"));
    }

    /**
     * A mixture component.
     *
     * @param priori the mixing weight (prior probability) of the component.
     * @param distribution the component distribution.
     */
    public record Component(double priori, DiscreteDistribution distribution) {
    }
}

