/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.ml.optim.aggregator;

import java.util.Arrays;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.ml.feature.Instance;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator;
import scala.Function0;
import scala.Function2;
import scala.MatchError;
import scala.Predef$;
import scala.Serializable;
import scala.collection.mutable.ArrayOps;
import scala.math.package$;
import scala.reflect.ScalaSignature;
import scala.runtime.DoubleRef;
import scala.runtime.java8.JFunction2;
import scala.runtime.java8.JFunction2$mcVID$sp;

@ScalaSignature(bytes="\u0006\u0001!4Q\u0001E\t\u0001+uA\u0001b\f\u0001\u0003\u0002\u0003\u0006I!\r\u0005\ti\u0001\u0011\t\u0011)A\u0005k!A\u0001\b\u0001B\u0001B\u0003%\u0011\b\u0003\u0005C\u0001\t\u0005\t\u0015!\u0003D\u0011\u0015Q\u0005\u0001\"\u0001L\u0011\u001d\t\u0006A1A\u0005RICaA\u0016\u0001!\u0002\u0013\u0019\u0006bB,\u0001\u0005\u0004%IA\u0015\u0005\u00071\u0002\u0001\u000b\u0011B*\t\u000fe\u0003!\u0019!C\u00055\"11\f\u0001Q\u0001\nUBq\u0001\u0018\u0001C\u0002\u0013%!\f\u0003\u0004^\u0001\u0001\u0006I!\u000e\u0005\t=\u0002A)\u0019!C\u0005?\")A\r\u0001C\u0001K\ny\u0001*\u001e2fe\u0006;wM]3hCR|'O\u0003\u0002\u0013'\u0005Q\u0011mZ4sK\u001e\fGo\u001c:\u000b\u0005Q)\u0012!B8qi&l'B\u0001\f\u0018\u0003\tiGN\u0003\u0002\u00193\u0005)1\u000f]1sW*\u0011!dG\u0001\u0007CB\f7\r[3\u000b\u0003q\t1a\u001c:h'\r\u0001a\u0004\n\t\u0003?\tj\u0011\u0001\t\u0006\u0002C\u0005)1oY1mC&\u00111\u0005\t\u0002\u0007\u0003:L(+\u001a4\u0011\t\u00152\u0003FL\u0007\u0002#%\u0011q%\u0005\u0002\u001d\t&4g-\u001a:f]RL\u0017M\u00197f\u0019>\u001c8/Q4he\u0016<\u0017\r^8s!\tIC&D\u0001+\u0015\tYS#A\u0004gK\u0006$XO]3\n\u00055R#\u0001C%ogR\fgnY3\u0011\u0005\u0015\u0002\u0011\u0001\u00044ji&sG/\u001a:dKB$8\u0001\u0001\t\u0003?IJ!a\r\u0011\u0003\u000f\t{w\u000e\\3b]\u00069Q\r]:jY>t\u0007CA\u00107\u0013\t9\u0004E\u0001\u0004E_V\u0014G.Z\u0001\u000eE\u000e4U-\u0019;ve\u0016\u001c8\u000b\u001e3\u0011\u0007ijt(D\u0001<\u0015\tat#A\u0005ce>\fGmY1ti&\u0011ah\u000f\u0002\n\u0005J|\u0017\rZ2bgR\u00042a\b!6\u0013\t\t\u0005EA\u0003BeJ\f\u00170\u0001\u0007cGB\u000b'/Y7fi\u0016\u00148\u000fE\u0002;{\u0011\u0003\"!\u0012%\u000e\u0003\u0019S!aR\u000b\u0002\r1Lg.\u00197h\u0013\tIeI\u0001\u0004WK\u000e$xN]\u0001\u0007y%t\u0017\u000e\u001e 
\u0015\t1su\n\u0015\u000b\u0003]5CQAQ\u0003A\u0002\rCQaL\u0003A\u0002EBQ\u0001N\u0003A\u0002UBQ\u0001O\u0003A\u0002e\n1\u0001Z5n+\u0005\u0019\u0006CA\u0010U\u0013\t)\u0006EA\u0002J]R\fA\u0001Z5nA\u0005Ya.^7GK\u0006$XO]3t\u00031qW/\u001c$fCR,(/Z:!\u0003\u0015\u0019\u0018nZ7b+\u0005)\u0014AB:jO6\f\u0007%A\u0005j]R,'oY3qi\u0006Q\u0011N\u001c;fe\u000e,\u0007\u000f\u001e\u0011\u0002\u0019\r|WM\u001a4jG&,g\u000e^:\u0016\u0003}B#AD1\u0011\u0005}\u0011\u0017BA2!\u0005%!(/\u00198tS\u0016tG/A\u0002bI\u0012$\"A\f4\t\u000b\u001d|\u0001\u0019\u0001\u0015\u0002\u0011%t7\u000f^1oG\u0016\u0004")
/**
 * Accumulates the Huber loss and its gradient over weighted {@link Instance}s for a linear
 * model whose scale parameter {@code sigma} is estimated jointly with the coefficients.
 *
 * <p>The broadcast parameter vector (and the accumulated gradient) is laid out as
 * {@code [coefficients (numFeatures entries), intercept (only if fitIntercept), sigma]}.
 * Coefficients act on scaled features: every feature value is divided by its entry in
 * {@code bcFeaturesStd}, and features whose entry is exactly {@code 0.0} are skipped.
 *
 * <p>For a residual {@code r = label - margin}, an instance of weight {@code w} adds
 * <pre>{@code
 *   0.5 * w * (sigma + r^2 / sigma)                             if |r| <= sigma * epsilon
 *   0.5 * w * (sigma + 2 * epsilon * |r| - sigma * epsilon^2)   otherwise
 * }</pre>
 * to the loss sum.
 *
 * <p>{@link #add} updates the running sums without any locking; only the two lazy
 * initializers are synchronized. Do not call {@code add} concurrently on one instance.
 */
public class HuberAggregator
implements DifferentiableLossAggregator<Instance, HuberAggregator> {
    // Leading numFeatures entries of the parameter vector, built lazily. It is transient, as is
    // its guard flag, so it is not serialized; after deserialization the flag reads false and
    // the array is rebuilt from the broadcast.
    private transient double[] coefficients;
    private final boolean fitIntercept;
    // Switch point between the quadratic and linear regions, in units of sigma.
    private final double epsilon;
    // Per-feature divisors applied to feature values (presumably standard deviations, per the name).
    private final Broadcast<double[]> bcFeaturesStd;
    // Full parameter vector: coefficients, optional intercept, then sigma as the last entry.
    private final Broadcast<Vector> bcParameters;
    private final int dim;
    private final int numFeatures;
    private final double sigma;
    private final double intercept;
    private double weightSum;
    private double lossSum;
    private double[] gradientSumArray;
    // Guard flag for the lazily built coefficients array.
    private volatile transient boolean bitmap$trans$0;
    // Guard flag for the lazily built gradientSumArray.
    private volatile boolean bitmap$0;

    /** Merges another aggregator into this one; delegates to the interface's implementation. */
    @Override
    public DifferentiableLossAggregator merge(DifferentiableLossAggregator other) {
        return DifferentiableLossAggregator.merge$(this, other);
    }

    /** Returns the accumulated gradient; delegates to the interface's implementation. */
    @Override
    public Vector gradient() {
        return DifferentiableLossAggregator.gradient$(this);
    }

    /** Returns the accumulated weight; delegates to the interface's implementation. */
    @Override
    public double weight() {
        return DifferentiableLossAggregator.weight$(this);
    }

    /** Returns the accumulated loss; delegates to the interface's implementation. */
    @Override
    public double loss() {
        return DifferentiableLossAggregator.loss$(this);
    }

    /** Returns the running sum of instance weights. */
    @Override
    public double weightSum() {
        return this.weightSum;
    }

    /** Sets the running sum of instance weights. */
    @Override
    public void weightSum_$eq(double x$1) {
        this.weightSum = x$1;
    }

    /** Returns the running sum of weighted losses. */
    @Override
    public double lossSum() {
        return this.lossSum;
    }

    /** Sets the running sum of weighted losses. */
    @Override
    public void lossSum_$eq(double x$1) {
        this.lossSum = x$1;
    }

    /** Builds {@link #gradientSumArray} exactly once, under this object's monitor. */
    private double[] gradientSumArray$lzycompute() {
        synchronized (this) {
            if (!this.bitmap$0) {
                this.gradientSumArray = DifferentiableLossAggregator.gradientSumArray$(this);
                this.bitmap$0 = true;
            }
        }
        return this.gradientSumArray;
    }

    /** Returns the gradient accumulator, creating it on first use. */
    @Override
    public double[] gradientSumArray() {
        return this.bitmap$0 ? this.gradientSumArray : this.gradientSumArray$lzycompute();
    }

    /** Returns the size of the parameter vector (and of the gradient). */
    @Override
    public int dim() {
        return this.dim;
    }

    private int numFeatures() {
        return this.numFeatures;
    }

    private double sigma() {
        return this.sigma;
    }

    private double intercept() {
        return this.intercept;
    }

    /** Builds {@link #coefficients} exactly once, under this object's monitor. */
    private double[] coefficients$lzycompute() {
        synchronized (this) {
            if (!this.bitmap$trans$0) {
                double[] params = this.bcParameters.value().toArray();
                // Equivalent to Scala's slice(0, numFeatures): always a fresh copy (never the
                // vector's backing array), and empty rather than failing if numFeatures < 0.
                // numFeatures <= dim - 1 < params.length, so no zero-padding can occur.
                this.coefficients = Arrays.copyOf(params, Math.max(this.numFeatures, 0));
                this.bitmap$trans$0 = true;
            }
        }
        return this.coefficients;
    }

    private double[] coefficients() {
        return this.bitmap$trans$0 ? this.coefficients : this.coefficients$lzycompute();
    }

    /**
     * Adds one weighted instance's Huber loss and gradient to the running sums.
     *
     * @param instance labelled, weighted sample; its feature count must equal the number of
     *     coefficients and its weight must be non-negative
     * @return this aggregator
     * @throws IllegalArgumentException (raised by {@code Predef.require}) on a feature-count
     *     mismatch or a negative weight
     * @throws MatchError if {@code instance} is {@code null}
     */
    @Override
    public HuberAggregator add(Instance instance) {
        if (instance == null) {
            // Preserved behavior: a null instance is rejected with scala.MatchError.
            throw new MatchError(instance);
        }
        final double label = instance.label();
        final double weight = instance.weight();
        final Vector features = instance.features();
        Predef$.MODULE$.require(this.numFeatures == features.size(), (Function0<Object>) () ->
                "Dimensions mismatch when adding new sample." + " Expecting " + this.numFeatures
                        + " but got " + features.size() + ".");
        Predef$.MODULE$.require(weight >= 0.0, (Function0<Object>) () ->
                "instance weight, " + weight + " has to be >= 0.0");
        if (weight == 0.0) {
            // Zero-weight samples contribute nothing, not even to weightSum.
            return this;
        }

        final double[] localFeaturesStd = this.bcFeaturesStd.value();
        final double[] localCoefficients = this.coefficients();
        final double[] localGradientSumArray = this.gradientSumArray();

        final double margin = this.computeMargin(features, localFeaturesStd, localCoefficients);
        final double linearLoss = label - margin;

        if (Math.abs(linearLoss) <= this.sigma * this.epsilon) {
            // Quadratic region.
            this.lossSum_$eq(this.lossSum()
                    + 0.5 * weight * (this.sigma + Math.pow(linearLoss, 2.0) / this.sigma));
            final double linearLossDivSigma = linearLoss / this.sigma;
            // Computed as (-1.0 * weight) * linearLossDivSigma, the same evaluation order as the
            // inline expression it replaces, so results are bit-identical.
            final double multiplier = -1.0 * weight * linearLossDivSigma;
            addScaledFeatures(features, localFeaturesStd, localGradientSumArray, multiplier);
            if (this.fitIntercept) {
                localGradientSumArray[this.dim() - 2] += multiplier;
            }
            // Gradient with respect to sigma (last slot).
            localGradientSumArray[this.dim() - 1] +=
                    0.5 * weight * (1.0 - Math.pow(linearLossDivSigma, 2.0));
        } else {
            // Linear region.
            final double sign = linearLoss >= 0.0 ? -1.0 : 1.0;
            this.lossSum_$eq(this.lossSum() + 0.5 * weight * (this.sigma
                    + 2.0 * this.epsilon * Math.abs(linearLoss)
                    - this.sigma * this.epsilon * this.epsilon));
            final double multiplier = weight * sign * this.epsilon;
            addScaledFeatures(features, localFeaturesStd, localGradientSumArray, multiplier);
            if (this.fitIntercept) {
                localGradientSumArray[this.dim() - 2] += multiplier;
            }
            // Gradient with respect to sigma (last slot).
            localGradientSumArray[this.dim() - 1] +=
                    0.5 * weight * (1.0 - this.epsilon * this.epsilon);
        }

        this.weightSum_$eq(this.weightSum() + weight);
        return this;
    }

    /**
     * Returns the dot product of the coefficients with the scaled features, plus the intercept
     * when {@code fitIntercept} is set. Features whose divisor is {@code 0.0} are skipped.
     */
    @SuppressWarnings("unchecked") // raw specialized Scala Function2 passed to foreachNonZero
    private double computeMargin(Vector features, double[] featuresStd, double[] coef) {
        final DoubleRef sum = DoubleRef.create(0.0);
        features.foreachNonZero((JFunction2$mcVID$sp) (index, value) -> {
            if (featuresStd[index] != 0.0) {
                sum.elem += coef[index] * (value / featuresStd[index]);
            }
        });
        if (this.fitIntercept) {
            sum.elem += this.intercept;
        }
        return sum.elem;
    }

    /**
     * Adds {@code multiplier * (value / featuresStd[index])} to {@code gradient[index]} for every
     * non-zero feature whose divisor is not {@code 0.0}.
     */
    @SuppressWarnings("unchecked") // raw specialized Scala Function2 passed to foreachNonZero
    private static void addScaledFeatures(
            Vector features, double[] featuresStd, double[] gradient, double multiplier) {
        features.foreachNonZero((JFunction2$mcVID$sp) (index, value) -> {
            if (featuresStd[index] != 0.0) {
                gradient[index] += multiplier * (value / featuresStd[index]);
            }
        });
    }

    /**
     * Creates an aggregator for one pass over the data.
     *
     * @param fitIntercept whether the parameter vector carries an intercept before sigma
     * @param epsilon switch point between the quadratic and linear regions, in units of sigma
     * @param bcFeaturesStd per-feature divisors applied to feature values
     * @param bcParameters parameter vector laid out as coefficients, optional intercept, sigma
     */
    public HuberAggregator(boolean fitIntercept, double epsilon, Broadcast<double[]> bcFeaturesStd, Broadcast<Vector> bcParameters) {
        this.fitIntercept = fitIntercept;
        this.epsilon = epsilon;
        this.bcFeaturesStd = bcFeaturesStd;
        this.bcParameters = bcParameters;
        DifferentiableLossAggregator.$init$(this);
        // Read the broadcast once, and use the field rather than the overridable dim() here.
        final Vector params = bcParameters.value();
        this.dim = params.size();
        this.numFeatures = fitIntercept ? this.dim - 2 : this.dim - 1;
        this.sigma = params.apply(this.dim - 1);
        this.intercept = fitIntercept ? params.apply(this.dim - 2) : 0.0;
    }
}

