/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.regression.slm;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.config.PropertyException;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import java.time.OffsetDateTime;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.SplittableRandom;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.SparseModel;
import org.tribuo.SparseTrainer;
import org.tribuo.Trainer;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.VectorTuple;
import org.tribuo.provenance.DatasetProvenance;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.provenance.impl.TrainerProvenanceImpl;
import org.tribuo.regression.Regressor;
import org.tribuo.regression.slm.SparseLinearModel;
import org.tribuo.util.Util;

public class ElasticNetCDTrainer
implements SparseTrainer<Regressor> {
    private static final Logger logger = Logger.getLogger(ElasticNetCDTrainer.class.getName());
    @Config(mandatory=true, description="Overall regularisation penalty.")
    private double alpha;
    @Config(mandatory=true, description="Ratio of l1 to l2 parameters.")
    private double l1Ratio;
    @Config(description="Tolerance on the error.")
    private double tolerance = 1.0E-4;
    @Config(description="Maximium number of iterations to run.")
    private int maxIterations = 500;
    @Config(description="Randomises the order in which the features are probed.")
    private boolean randomise = false;
    @Config(description="The seed for the RNG.")
    private long seed = 12345L;
    private SplittableRandom rng;
    private int trainInvocationCounter;

    private ElasticNetCDTrainer() {
    }

    public ElasticNetCDTrainer(double alpha, double l1Ratio) {
        this(alpha, l1Ratio, 1.0E-4, 500, false, 12345L);
    }

    public ElasticNetCDTrainer(double alpha, double l1Ratio, long seed) {
        this(alpha, l1Ratio, 1.0E-4, 500, true, seed);
    }

    public ElasticNetCDTrainer(double alpha, double l1Ratio, double tolerance, int maxIterations, boolean randomise, long seed) {
        this.alpha = alpha;
        this.l1Ratio = l1Ratio;
        this.tolerance = tolerance;
        this.maxIterations = maxIterations;
        this.randomise = randomise;
        this.seed = seed;
        this.postConfig();
    }

    public synchronized void postConfig() {
        if (this.l1Ratio < 1.0E-12 || this.l1Ratio > 1.000000000001) {
            throw new PropertyException("l1Ratio", "L1 Ratio must be between 0 and 1. Found value " + this.l1Ratio);
        }
        this.rng = new SplittableRandom(this.seed);
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    public SparseModel<Regressor> train(Dataset<Regressor> examples, Map<String, Provenance> runProvenance) {
        int i;
        TrainerProvenance trainerProvenance;
        SplittableRandom localRNG;
        if (examples.getOutputInfo().getUnknownCount() > 0) {
            throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
        }
        ElasticNetCDTrainer elasticNetCDTrainer = this;
        synchronized (elasticNetCDTrainer) {
            localRNG = this.rng.split();
            trainerProvenance = this.getProvenance();
            ++this.trainInvocationCounter;
        }
        ImmutableFeatureMap featureIDMap = examples.getFeatureIDMap();
        ImmutableOutputInfo outputInfo = examples.getOutputIDInfo();
        int numFeatures = featureIDMap.size();
        int numOutputs = outputInfo.size();
        int numExamples = examples.size();
        SparseVector[] columns = SparseVector.transpose(examples, (ImmutableFeatureMap)featureIDMap);
        String[] dimensionNames = new String[numOutputs];
        DenseVector[] regressionTargets = new DenseVector[numOutputs];
        for (i = 0; i < numOutputs; ++i) {
            dimensionNames[i] = ((Regressor)outputInfo.getOutput(i)).getNames()[0];
            regressionTargets[i] = new DenseVector(numExamples);
        }
        i = 0;
        for (Example e : examples) {
            int j = 0;
            for (Regressor.DimensionTuple d : (Regressor)e.getOutput()) {
                regressionTargets[j].set(i, d.getValue());
                ++j;
            }
            ++i;
        }
        double l1Penalty = this.alpha * this.l1Ratio * (double)numExamples;
        double l2Penalty = this.alpha * (1.0 - this.l1Ratio) * (double)numExamples;
        double[] featureMeans = ElasticNetCDTrainer.calculateMeans((SGDVector[])columns);
        double[] featureVariances = new double[columns.length];
        Arrays.fill(featureVariances, 1.0);
        boolean center = false;
        for (i = 0; i < numFeatures; ++i) {
            if (!(Math.abs(featureMeans[i]) > 1.0E-12)) continue;
            center = true;
            break;
        }
        double[] columnNorms = new double[numFeatures];
        int[] featureIndices = new int[numFeatures];
        for (i = 0; i < numFeatures; ++i) {
            featureIndices[i] = i;
            double variance = 0.0;
            for (VectorTuple v : columns[i]) {
                variance += (v.value - featureMeans[i]) * (v.value - featureMeans[i]);
            }
            columnNorms[i] = variance + (double)(numExamples - columns[i].numActiveElements()) * featureMeans[i] * featureMeans[i];
        }
        ElasticNetState elState = new ElasticNetState(columns, featureIndices, featureMeans, columnNorms, l1Penalty, l2Penalty, center);
        SparseVector[] outputWeights = new SparseVector[numOutputs];
        double[] outputMeans = new double[numOutputs];
        for (int j = 0; j < dimensionNames.length; ++j) {
            outputWeights[j] = this.trainSingleDimension(regressionTargets[j], elState, localRNG.split());
            outputMeans[j] = regressionTargets[j].sum() / (double)numExamples;
        }
        double[] outputVariances = new double[numOutputs];
        Arrays.fill(outputVariances, 1.0);
        ModelProvenance provenance = new ModelProvenance(SparseLinearModel.class.getName(), OffsetDateTime.now(), (DatasetProvenance)examples.getProvenance(), trainerProvenance, runProvenance);
        return new SparseLinearModel("elastic-net-model", dimensionNames, provenance, featureIDMap, (ImmutableOutputInfo<Regressor>)outputInfo, outputWeights, DenseVector.createDenseVector((double[])featureMeans), DenseVector.createDenseVector((double[])featureVariances), outputMeans, outputVariances, false);
    }

    private SparseVector trainSingleDimension(DenseVector regressionTargets, ElasticNetState state, SplittableRandom localRNG) {
        int numFeatures = state.numFeatures;
        int numExamples = state.numExamples;
        DenseVector residuals = regressionTargets.copy();
        DenseVector weights = new DenseVector(numFeatures);
        double targetTwoNorm = regressionTargets.twoNorm();
        double newTolerance = this.tolerance * targetTwoNorm * targetTwoNorm;
        double[] xTransposeR = new double[numFeatures];
        double[] xTransposeAlpha = new double[numFeatures];
        for (int i = 0; i < this.maxIterations; ++i) {
            double dualityGap;
            double scalingFactor;
            double maxWeight = 0.0;
            double maxUpdate = 0.0;
            if (this.randomise) {
                Util.randpermInPlace((int[])state.featureIndices, (SplittableRandom)localRNG);
            }
            for (int j = 0; j < numFeatures; ++j) {
                double absNewWeight;
                double curUpdate;
                int feature = state.featureIndices[j];
                if (Math.abs(state.columnNorms[feature]) < 1.0E-12) continue;
                double oldWeight = weights.get(feature);
                if (oldWeight != 0.0) {
                    for (VectorTuple v : state.columns[feature]) {
                        residuals.set(v.index, residuals.get(v.index) + v.value * oldWeight);
                    }
                    if (state.center) {
                        for (int k = 0; k < numExamples; ++k) {
                            residuals.set(k, residuals.get(k) - state.featureMeans[feature] * oldWeight);
                        }
                    }
                }
                double curDot = residuals.dot((SGDVector)state.columns[feature]);
                if (state.center) {
                    curDot -= residuals.sum() * state.featureMeans[feature];
                }
                double newWeight = Math.signum(curDot) * Math.max(Math.abs(curDot) - state.l1Penalty, 0.0) / (state.columnNorms[feature] + state.l2Penalty);
                weights.set(feature, newWeight);
                if (newWeight != 0.0) {
                    for (VectorTuple v : state.columns[feature]) {
                        residuals.set(v.index, residuals.get(v.index) - v.value * newWeight);
                    }
                    if (state.center) {
                        for (int k = 0; k < numExamples; ++k) {
                            residuals.set(k, residuals.get(k) + state.featureMeans[feature] * newWeight);
                        }
                    }
                }
                if ((curUpdate = Math.abs(newWeight - oldWeight)) > maxUpdate) {
                    maxUpdate = curUpdate;
                }
                if (!((absNewWeight = Math.abs(newWeight)) > maxWeight)) continue;
                maxWeight = absNewWeight;
            }
            if (!(maxWeight < 1.0E-12) && !(maxUpdate / maxWeight < this.tolerance) && i != this.maxIterations - 1) continue;
            double residualSum = residuals.sum();
            double maxAbsXTA = 0.0;
            for (int j = 0; j < numFeatures; ++j) {
                xTransposeR[j] = residuals.dot((SGDVector)state.columns[j]);
                if (state.center) {
                    int n = j;
                    xTransposeR[n] = xTransposeR[n] - state.featureMeans[j] * residualSum;
                }
                xTransposeAlpha[j] = xTransposeR[j] - state.l2Penalty * weights.get(j);
                double curAbs = Math.abs(xTransposeAlpha[j]);
                if (!(curAbs > maxAbsXTA)) continue;
                maxAbsXTA = curAbs;
            }
            double residualTwoNorm = residuals.twoNorm();
            residualTwoNorm *= residualTwoNorm;
            double weightsTwoNorm = weights.twoNorm();
            weightsTwoNorm *= weightsTwoNorm;
            double weightsOneNorm = weights.oneNorm();
            if (maxAbsXTA > state.l1Penalty) {
                scalingFactor = state.l1Penalty / maxAbsXTA;
                double alphaNorm = residualTwoNorm * scalingFactor * scalingFactor;
                dualityGap = 0.5 * (residualTwoNorm + alphaNorm);
            } else {
                scalingFactor = 1.0;
                dualityGap = residualTwoNorm;
            }
            dualityGap += state.l1Penalty * weightsOneNorm - scalingFactor * residuals.dot((SGDVector)regressionTargets);
            dualityGap += 0.5 * state.l2Penalty * (1.0 + scalingFactor * scalingFactor) * weightsTwoNorm;
            if (!(dualityGap < newTolerance)) continue;
            logger.log(Level.INFO, "Iteration: " + i + ", duality gap = " + dualityGap + ", tolerance = " + newTolerance);
            break;
        }
        return weights.sparsify();
    }

    public int getInvocationCount() {
        return this.trainInvocationCounter;
    }

    public String toString() {
        return "ElasticNetCDTrainer(alpha=" + this.alpha + ",l1Ratio=" + this.l1Ratio + ",tolerance=" + this.tolerance + ",maxIterations=" + this.maxIterations + ",randomise=" + this.randomise + ",seed=" + this.seed + ")";
    }

    private static double[] calculateMeans(SGDVector[] columns) {
        double[] means = new double[columns.length];
        for (int i = 0; i < means.length; ++i) {
            means[i] = columns[i].sum() / (double)columns[i].size();
        }
        return means;
    }

    private static double[] calculateVariances(SGDVector[] columns, double[] means) {
        double[] variances = new double[columns.length];
        for (int i = 0; i < variances.length; ++i) {
            variances[i] = columns[i].variance(means[i]);
        }
        return variances;
    }

    public TrainerProvenance getProvenance() {
        return new TrainerProvenanceImpl((Trainer)this);
    }

    private static class ElasticNetState {
        final SparseVector[] columns;
        final int numFeatures;
        final int numExamples;
        final int[] featureIndices;
        final double[] featureMeans;
        final double[] columnNorms;
        final double l1Penalty;
        final double l2Penalty;
        final boolean center;

        public ElasticNetState(SparseVector[] columns, int[] featureIndices, double[] featureMeans, double[] columnNorms, double l1Penalty, double l2Penalty, boolean center) {
            this.columns = columns;
            this.numFeatures = columns.length;
            this.numExamples = columns[0].size();
            this.featureIndices = featureIndices;
            this.featureMeans = featureMeans;
            this.columnNorms = columnNorms;
            this.l1Penalty = l1Penalty;
            this.l2Penalty = l2Penalty;
            this.center = center;
        }
    }
}

