/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.regression.slm;

import com.oracle.labs.mlrg.olcut.util.Pair;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.PriorityQueue;
import java.util.logging.Logger;
import org.tribuo.Example;
import org.tribuo.Excuse;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Prediction;
import org.tribuo.VariableIDInfo;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.Tensor;
import org.tribuo.math.la.VectorTuple;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.regression.Regressor;
import org.tribuo.regression.impl.SkeletalIndependentRegressionSparseModel;

/**
 * A sparse linear regression model holding one sparse weight vector per output dimension.
 * <p>
 * At prediction time the example's features are standardized: the stored feature means are
 * subtracted and the result is divided element-wise by {@code featureVariance}. Each dimension's
 * weight vector is then applied. The resulting score is mapped back to the output scale by
 * multiplying by {@code yVariance[i]} and adding {@code yMean[i]}.
 * <p>
 * NOTE(review): despite their names, {@code featureVariance} and {@code yVariance} are used as
 * scale factors (divided by / multiplied by), so they presumably hold standard deviations —
 * confirm against the trainer.
 * <p>
 * When the model has a bias term it occupies feature index {@code featureIDMap.size()} (one past
 * the last real feature id). It is reported under the name "BIAS".
 */
public class SparseLinearModel
extends SkeletalIndependentRegressionSparseModel {
    private static final long serialVersionUID = 3L;
    private static final Logger logger = Logger.getLogger(SparseLinearModel.class.getName());
    /** Display name used for the bias feature, which has no entry in the feature map. */
    private static final String BIAS_FEATURE_NAME = "BIAS";
    private final SparseVector[] weights;
    private final DenseVector featureMeans;
    private final DenseVector featureVariance;
    private final boolean bias;
    private final double[] yMean;
    private final double[] yVariance;

    /**
     * Constructs a sparse linear model.
     *
     * @param name            the model name.
     * @param dimensionNames  the output dimension names, parallel to {@code weights}, {@code yMean} and {@code yVariance}.
     * @param description     the model provenance.
     * @param featureIDMap    the feature domain.
     * @param labelIDMap      the output domain.
     * @param weights         one sparse weight vector per output dimension.
     * @param featureMeans    per-feature values subtracted from the input features.
     * @param featureVariance per-feature values the shifted input features are divided by.
     * @param yMean           per-dimension offsets added to the scaled prediction.
     * @param yVariance       per-dimension factors the raw prediction is multiplied by.
     * @param bias            whether a bias feature is appended to each example's features.
     */
    SparseLinearModel(String name, String[] dimensionNames, ModelProvenance description, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Regressor> labelIDMap, SparseVector[] weights, DenseVector featureMeans, DenseVector featureVariance, double[] yMean, double[] yVariance, boolean bias) {
        super(name, dimensionNames, description, featureIDMap, labelIDMap, SparseLinearModel.generateActiveFeatures(dimensionNames, featureIDMap, weights));
        this.weights = weights;
        this.featureMeans = featureMeans;
        this.featureVariance = featureVariance;
        this.bias = bias;
        this.yVariance = yVariance;
        this.yMean = yMean;
    }

    /**
     * Returns the display name of the feature with the given index.
     * Indices with no entry in the map (i.e. the bias slot at {@code featureMap.size()}) are
     * reported as "BIAS".
     */
    private static String lookupFeatureName(ImmutableFeatureMap featureMap, int index) {
        VariableIDInfo info = featureMap.get(index);
        return info == null ? BIAS_FEATURE_NAME : info.getName();
    }

    /**
     * Builds the map from output dimension name to the names of the features that have a stored
     * (active) weight in that dimension.
     */
    private static Map<String, List<String>> generateActiveFeatures(String[] dimensionNames, ImmutableFeatureMap featureMap, SparseVector[] weightsArray) {
        Map<String, List<String>> map = new HashMap<>();
        for (int i = 0; i < dimensionNames.length; ++i) {
            List<String> featureNames = new ArrayList<>();
            for (VectorTuple v : weightsArray[i]) {
                featureNames.add(lookupFeatureName(featureMap, v.index));
            }
            map.put(dimensionNames[i], featureNames);
        }
        return map;
    }

    /**
     * Converts an example into a standardized sparse feature vector.
     * It optionally appends the bias, subtracts the feature means, then divides by
     * {@code featureVariance}.
     */
    @Override
    protected SparseVector createFeatures(Example<Regressor> example) {
        SparseVector features = SparseVector.createSparseVector(example, featureIDMap, bias);
        features.intersectAndAddInPlace(featureMeans, a -> -a);
        features.hadamardProductInPlace(featureVariance, a -> 1.0 / a);
        return features;
    }

    /**
     * Scores one output dimension and maps the result back to the output scale.
     *
     * @param dimensionIdx the output dimension index.
     * @param features     the standardized features produced by {@link #createFeatures}.
     * @return the named prediction for this dimension.
     */
    @Override
    protected Regressor.DimensionTuple scoreDimension(int dimensionIdx, SparseVector features) {
        SparseVector dimensionWeights = weights[dimensionIdx];
        // A dimension with no active weights contributes nothing in standardized space, so its
        // prediction is the mean. (Previously this fell back to 1.0, which predicted mean + scale.)
        double standardized = dimensionWeights.numActiveElements() > 0 ? dimensionWeights.dot(features) : 0.0;
        double prediction = standardized * yVariance[dimensionIdx] + yMean[dimensionIdx];
        return new Regressor.DimensionTuple(dimensions[dimensionIdx], prediction);
    }

    /**
     * Returns, for each output dimension, the features with the largest absolute weights,
     * in descending order of magnitude.
     *
     * @param n the number of features to return per dimension. A negative value returns every
     *          active feature, and zero returns empty lists.
     * @return a map from dimension name to (feature name, weight) pairs.
     */
    @Override
    public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) {
        // +1 leaves room for the bias feature when every feature is requested.
        int maxFeatures = n < 0 ? featureIDMap.size() + 1 : n;
        Map<String, List<Pair<String, Double>>> map = new HashMap<>();
        if (maxFeatures == 0) {
            // PriorityQueue rejects a zero capacity, and no features were asked for anyway.
            for (String dimension : dimensions) {
                map.put(dimension, new ArrayList<>());
            }
            return map;
        }
        Comparator<Pair<String, Double>> comparator = Comparator.comparingDouble(p -> Math.abs(p.getB()));
        // Min-heap on |weight| so the smallest retained feature is evicted first.
        PriorityQueue<Pair<String, Double>> q = new PriorityQueue<>(maxFeatures, comparator);
        for (int i = 0; i < dimensions.length; ++i) {
            q.clear();
            for (VectorTuple v : weights[i]) {
                Pair<String, Double> curr = new Pair<>(lookupFeatureName(featureIDMap, v.index), v.value);
                if (q.size() < maxFeatures) {
                    q.offer(curr);
                } else if (comparator.compare(curr, q.peek()) > 0) {
                    q.poll();
                    q.offer(curr);
                }
            }
            List<Pair<String, Double>> top = new ArrayList<>(q.size());
            while (!q.isEmpty()) {
                top.add(q.poll());
            }
            // The heap drains smallest-first; reverse for descending magnitude.
            Collections.reverse(top);
            map.put(dimensions[i], top);
        }
        return map;
    }

    /**
     * Explains a prediction by listing, per output dimension, each present feature's
     * contribution (weight times standardized feature value), sorted in descending order.
     * The bias feature, if present, is reported as "BIAS".
     */
    @Override
    public Optional<Excuse<Regressor>> getExcuse(Example<Regressor> example) {
        Prediction<Regressor> prediction = predict(example);
        Map<String, List<Pair<String, Double>>> weightMap = new HashMap<>();
        SparseVector features = createFeatures(example);
        for (int i = 0; i < dimensions.length; ++i) {
            List<Pair<String, Double>> classScores = new ArrayList<>();
            for (VectorTuple f : features) {
                double score = weights[i].get(f.index) * f.value;
                // The bias slot has no feature-map entry; looking it up directly would NPE.
                classScores.add(new Pair<>(lookupFeatureName(featureIDMap, f.index), score));
            }
            classScores.sort((o1, o2) -> o2.getB().compareTo(o1.getB()));
            weightMap.put(dimensions[i], classScores);
        }
        return Optional.of(new Excuse<>(example, prediction, weightMap));
    }

    /** Returns a deep copy of this model under a new name and provenance. */
    @Override
    protected Model<Regressor> copy(String newName, ModelProvenance newProvenance) {
        return new SparseLinearModel(newName, Arrays.copyOf(dimensions, dimensions.length), newProvenance, featureIDMap, outputIDInfo, copyWeights(), featureMeans.copy(), featureVariance.copy(), Arrays.copyOf(yMean, yMean.length), Arrays.copyOf(yVariance, yVariance.length), bias);
    }

    /** Deep-copies the per-dimension weight vectors. */
    private SparseVector[] copyWeights() {
        SparseVector[] newWeights = new SparseVector[weights.length];
        for (int i = 0; i < weights.length; ++i) {
            newWeights[i] = weights[i].copy();
        }
        return newWeights;
    }

    /**
     * Returns a copy of the weight vectors, keyed by output dimension name.
     * Mutating the returned vectors does not affect the model.
     */
    public Map<String, SparseVector> getWeights() {
        SparseVector[] newWeights = copyWeights();
        Map<String, SparseVector> output = new HashMap<>();
        for (int i = 0; i < dimensions.length; ++i) {
            output.put(dimensions[i], newWeights[i]);
        }
        return output;
    }
}

