/*
 * Decompiled with CFR 0.152.
 */
package smile.neighbor;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.IntStream;
import smile.graph.NearestNeighborGraph;
import smile.math.MathEx;
import smile.neighbor.KNNSearch;
import smile.neighbor.Neighbor;
import smile.neighbor.NeighborBuilder;
import smile.neighbor.RandomProjectionTree;
import smile.sort.HeapSelect;

public class RandomProjectionForest
implements KNNSearch<double[], double[]> {
    private final List<FlatTree> trees;
    private final double[][] data;
    private final boolean angular;

    private RandomProjectionForest(List<FlatTree> trees, double[][] data, boolean angular) {
        this.trees = trees;
        this.data = data;
        this.angular = angular;
    }

    @Override
    public Neighbor<double[], double[]>[] search(double[] q, int k) {
        HeapSelect heap = new HeapSelect(NeighborBuilder.class, k);
        for (int i = 0; i < k; ++i) {
            heap.add(new NeighborBuilder());
        }
        HashSet<Integer> uniqueSamples = new HashSet<Integer>();
        for (FlatTree tree : this.trees) {
            int[] samples;
            for (int index : samples = tree.search(q)) {
                uniqueSamples.add(index);
            }
        }
        for (Integer index : uniqueSamples) {
            double dist;
            double[] x = this.data[index];
            double d = dist = this.angular ? MathEx.angular(q, x) : MathEx.distance(q, x);
            if (heap.size() < k) {
                heap.add(new NeighborBuilder<double[], double[]>(x, x, index, dist));
                continue;
            }
            NeighborBuilder top = (NeighborBuilder)heap.peek();
            if (!(dist < top.distance)) continue;
            top.distance = dist;
            top.index = index;
            top.key = x;
            top.value = x;
            heap.siftDown();
        }
        heap.sort();
        return (Neighbor[])Arrays.stream((NeighborBuilder[])heap.toArray()).map(NeighborBuilder::toNeighbor).toArray(Neighbor[]::new);
    }

    public NearestNeighborGraph toGraph(int k) {
        int n = this.data.length;
        ArrayList heapList = new ArrayList(n);
        ArrayList neighborSetList = new ArrayList(n);
        for (int i = 0; i < this.data.length; ++i) {
            heapList.add(new HeapSelect(NeighborBuilder.class, k));
            neighborSetList.add(new HashSet());
        }
        for (FlatTree tree : this.trees) {
            for (int[] leaf : tree.indices) {
                for (int li = 0; li < leaf.length; ++li) {
                    int i = leaf[li];
                    double[] xi = this.data[i];
                    for (int lj = li + 1; lj < leaf.length; ++lj) {
                        int j = leaf[lj];
                        double[] xj = this.data[j];
                        double dist = this.angular ? MathEx.angular(xi, xj) : MathEx.distance(xi, xj);
                        RandomProjectionForest.update((Set)neighborSetList.get(i), (HeapSelect)heapList.get(i), k, xj, j, dist);
                        RandomProjectionForest.update((Set)neighborSetList.get(j), (HeapSelect)heapList.get(j), k, xi, i, dist);
                    }
                }
            }
        }
        int[][] neighbors = new int[n][];
        double[][] distances = new double[n][];
        for (int i = 0; i < n; ++i) {
            HeapSelect pq = (HeapSelect)heapList.get(i);
            int m = Math.min(k, pq.size());
            neighbors[i] = new int[m];
            distances[i] = new double[m];
            pq.sort();
            NeighborBuilder[] a = (NeighborBuilder[])pq.toArray();
            int j = 0;
            int l = m - 1;
            while (j < m) {
                neighbors[i][j] = a[l].index;
                distances[i][j] = a[l].distance;
                ++j;
                --l;
            }
        }
        return new NearestNeighborGraph(k, neighbors, distances);
    }

    private static void update(Set<Integer> set, HeapSelect<NeighborBuilder<double[], double[]>> pq, int k, double[] x, int index, double dist) {
        if (!set.contains(index)) {
            if (pq.size() < k) {
                pq.add(new NeighborBuilder<double[], double[]>(x, x, index, dist));
                set.add(index);
            } else {
                NeighborBuilder<double[], double[]> top = pq.peek();
                if (dist < top.distance) {
                    set.remove(top.index);
                    set.add(index);
                    top.distance = dist;
                    top.index = index;
                    top.key = x;
                    top.value = x;
                    pq.siftDown();
                }
            }
        }
    }

    public static RandomProjectionForest of(double[][] data, int numTrees, int leafSize, boolean angular) {
        List<FlatTree> trees = IntStream.range(0, numTrees).parallel().mapToObj(i -> RandomProjectionTree.of(data, leafSize, angular).flatten()).toList();
        return new RandomProjectionForest(trees, data, angular);
    }

    record FlatTree(double[][] hyperplanes, double[] offsets, int[][] children, int[][] indices) {
        public int[] search(double[] point) {
            int node = 0;
            while (this.children[node][0] > 0) {
                boolean rightSide = RandomProjectionTree.isRightSide(point, this.hyperplanes[node], this.offsets[node]);
                node = this.children[node][rightSide ? 1 : 0];
            }
            return this.indices[-this.children[node][0]];
        }
    }
}

