/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.ml.tuning;

import com.github.fommil.netlib.F2jBLAS;
import org.apache.spark.Logging;
import org.apache.spark.annotation.Experimental;
import org.apache.spark.ml.Estimator;
import org.apache.spark.ml.Model;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.Transformer;
import org.apache.spark.ml.evaluation.Evaluator;
import org.apache.spark.ml.param.IntParam;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.param.Params$class;
import org.apache.spark.ml.tuning.CrossValidator$;
import org.apache.spark.ml.tuning.CrossValidatorModel;
import org.apache.spark.ml.tuning.CrossValidatorParams;
import org.apache.spark.ml.tuning.CrossValidatorParams$class;
import org.apache.spark.ml.util.Identifiable$;
import org.apache.spark.mllib.util.MLUtils$;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.StructType;
import scala.Array$;
import scala.Function0;
import scala.Function1;
import scala.MatchError;
import scala.Predef$;
import scala.Serializable;
import scala.StringContext;
import scala.Tuple2;
import scala.collection.Seq;
import scala.math.Ordering;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxesRunTime;
import scala.runtime.IntRef;

@Experimental
@ScalaSignature(bytes="\u0006\u0001\u0005Eb\u0001B\u0001\u0003\u00015\u0011ab\u0011:pgN4\u0016\r\\5eCR|'O\u0003\u0002\u0004\t\u00051A/\u001e8j]\u001eT!!\u0002\u0004\u0002\u00055d'BA\u0004\t\u0003\u0015\u0019\b/\u0019:l\u0015\tI!\"\u0001\u0004ba\u0006\u001c\u0007.\u001a\u0006\u0002\u0017\u0005\u0019qN]4\u0004\u0001M!\u0001A\u0004\f\u001a!\ry\u0001CE\u0007\u0002\t%\u0011\u0011\u0003\u0002\u0002\n\u000bN$\u0018.\\1u_J\u0004\"a\u0005\u000b\u000e\u0003\tI!!\u0006\u0002\u0003'\r\u0013xn]:WC2LG-\u0019;pe6{G-\u001a7\u0011\u0005M9\u0012B\u0001\r\u0003\u0005Q\u0019%o\\:t-\u0006d\u0017\u000eZ1u_J\u0004\u0016M]1ngB\u0011!dG\u0007\u0002\r%\u0011AD\u0002\u0002\b\u0019><w-\u001b8h\u0011!q\u0002A!b\u0001\n\u0003z\u0012aA;jIV\t\u0001\u0005\u0005\u0002\"O9\u0011!%J\u0007\u0002G)\tA%A\u0003tG\u0006d\u0017-\u0003\u0002'G\u00051\u0001K]3eK\u001aL!\u0001K\u0015\u0003\rM#(/\u001b8h\u0015\t13\u0005\u0003\u0005,\u0001\t\u0005\t\u0015!\u0003!\u0003\u0011)\u0018\u000e\u001a\u0011\t\u000b5\u0002A\u0011\u0001\u0018\u0002\rqJg.\u001b;?)\ty\u0003\u0007\u0005\u0002\u0014\u0001!)a\u0004\fa\u0001A!)Q\u0006\u0001C\u0001eQ\tq\u0006C\u00045\u0001\t\u0007I\u0011B\u001b\u0002\u000f\u0019\u0014$N\u0011'B'V\ta\u0007\u0005\u00028\u00016\t\u0001H\u0003\u0002:u\u00051a.\u001a;mS\nT!a\u000f\u001f\u0002\r\u0019|W.\\5m\u0015\tid(\u0001\u0004hSRDWO\u0019\u0006\u0002\u007f\u0005\u00191m\\7\n\u0005\u0005C$a\u0002$3U\nc\u0015i\u0015\u0005\u0007\u0007\u0002\u0001\u000b\u0011\u0002\u001c\u0002\u0011\u0019\u0014$N\u0011'B'\u0002BQ!\u0012\u0001\u0005\u0002\u0019\u000bAb]3u\u000bN$\u0018.\\1u_J$\"a\u0012%\u000e\u0003\u0001AQ!\u0013#A\u0002)\u000bQA^1mk\u0016\u0004$a\u0013(\u0011\u0007=\u0001B\n\u0005\u0002N\u001d2\u0001A!C(I\u0003\u0003\u0005\tQ!\u0001Q\u0005\ryFeM\t\u0003#R\u0003\"A\t*\n\u0005M\u001b#a\u0002(pi\"Lgn\u001a\t\u0003EUK!AV\u0012\u0003\u0007\u0005s\u0017\u0010C\u0003Y\u0001\u0011\u0005\u0011,A\u000btKR,5\u000f^5nCR|'\u000fU1sC6l\u0015\r]:\u0015\u0005\u001dS\u0006\"B%X\u0001\u0004Y\u0006c\u0001\u0012]=&\u0011
Ql\t\u0002\u0006\u0003J\u0014\u0018-\u001f\t\u0003?\nl\u0011\u0001\u0019\u0006\u0003C\u0012\tQ\u0001]1sC6L!a\u00191\u0003\u0011A\u000b'/Y7NCBDQ!\u001a\u0001\u0005\u0002\u0019\fAb]3u\u000bZ\fG.^1u_J$\"aR4\t\u000b%#\u0007\u0019\u00015\u0011\u0005%dW\"\u00016\u000b\u0005-$\u0011AC3wC2,\u0018\r^5p]&\u0011QN\u001b\u0002\n\u000bZ\fG.^1u_JDQa\u001c\u0001\u0005\u0002A\f1b]3u\u001dVlgi\u001c7egR\u0011q)\u001d\u0005\u0006\u0013:\u0004\rA\u001d\t\u0003EML!\u0001^\u0012\u0003\u0007%sG\u000fC\u0003w\u0001\u0011\u0005s/A\u0002gSR$\"A\u0005=\t\u000be,\b\u0019\u0001>\u0002\u000f\u0011\fG/Y:fiB\u00111P`\u0007\u0002y*\u0011QPB\u0001\u0004gFd\u0017BA@}\u0005%!\u0015\r^1Ge\u0006lW\rC\u0004\u0002\u0004\u0001!\t%!\u0002\u0002\u001fQ\u0014\u0018M\\:g_Jl7k\u00195f[\u0006$B!a\u0002\u0002\u0014A!\u0011\u0011BA\b\u001b\t\tYAC\u0002\u0002\u000eq\fQ\u0001^=qKNLA!!\u0005\u0002\f\tQ1\u000b\u001e:vGR$\u0016\u0010]3\t\u0011\u0005U\u0011\u0011\u0001a\u0001\u0003\u000f\taa]2iK6\f\u0007bBA\r\u0001\u0011\u0005\u00131D\u0001\u000fm\u0006d\u0017\u000eZ1uKB\u000b'/Y7t)\t\ti\u0002E\u0002#\u0003?I1!!\t$\u0005\u0011)f.\u001b;)\u0007\u0001\t)\u0003\u0005\u0003\u0002(\u00055RBAA\u0015\u0015\r\tYCB\u0001\u000bC:tw\u000e^1uS>t\u0017\u0002BA\u0018\u0003S\u0011A\"\u0012=qKJLW.\u001a8uC2\u0004")
// Estimator that chooses estimator parameters by k-fold cross-validation.
//
// fit() splits the input into `numFolds` (training, validation) pairs. For each fold it:
//   - fits the configured estimator once per ParamMap in `estimatorParamMaps`;
//   - scores each resulting model on that fold's validation data with `evaluator`.
// It then averages the scores per ParamMap across folds and refits the estimator on the
// full dataset with the ParamMap that has the highest average score. That refit model is
// wrapped in a CrossValidatorModel.
//
// NOTE(review): this is CFR output decompiled from Scala bytecode (see the file header),
// not hand-written Java, and it does not compile as Java as-is:
//   - `final` fields are assigned from the trait setter methods;
//   - anonymous classes are given constructor arguments;
//   - the type variables T and M used in fit() are unbound;
//   - `$anonfun$1` is not declared in this file.
// Read it as a guide to the Scala original rather than as maintainable source.
public class CrossValidator
extends Estimator<CrossValidatorModel>
implements CrossValidatorParams,
Logging {
    /** Unique identifier of this pipeline stage; assigned by the constructor. */
    private final String uid;
    /** BLAS instance used in fit() to scale the summed fold metrics into averages. */
    private final F2jBLAS f2jBLAS;
    // The four Param fields below back the CrossValidatorParams trait. They are assigned
    // through the `..._setter_$..._$eq` methods further down. Presumably those setters are
    // invoked by CrossValidatorParams$class.$init$(this) in the constructor -- TODO confirm.
    private final Param<Estimator<?>> estimator;
    private final Param<ParamMap[]> estimatorParamMaps;
    private final Param<Evaluator> evaluator;
    private final IntParam numFolds;

    /** Returns the Param holding the estimator to cross-validate. */
    @Override
    public Param<Estimator<?>> estimator() {
        return this.estimator;
    }

    /** Returns the Param holding the candidate ParamMaps that are compared against each other. */
    @Override
    public Param<ParamMap[]> estimatorParamMaps() {
        return this.estimatorParamMaps;
    }

    /** Returns the Param holding the evaluator used to score each fitted model. */
    @Override
    public Param<Evaluator> evaluator() {
        return this.evaluator;
    }

    /** Returns the Param holding the number of folds. */
    @Override
    public IntParam numFolds() {
        return this.numFolds;
    }

    // Scala trait-field setter (decompiler artifact): stores the `estimator` Param.
    @Override
    public void org$apache$spark$ml$tuning$CrossValidatorParams$_setter_$estimator_$eq(Param x$1) {
        this.estimator = x$1;
    }

    // Scala trait-field setter (decompiler artifact): stores the `estimatorParamMaps` Param.
    @Override
    public void org$apache$spark$ml$tuning$CrossValidatorParams$_setter_$estimatorParamMaps_$eq(Param x$1) {
        this.estimatorParamMaps = x$1;
    }

    // Scala trait-field setter (decompiler artifact): stores the `evaluator` Param.
    @Override
    public void org$apache$spark$ml$tuning$CrossValidatorParams$_setter_$evaluator_$eq(Param x$1) {
        this.evaluator = x$1;
    }

    // Scala trait-field setter (decompiler artifact): stores the `numFolds` Param.
    @Override
    public void org$apache$spark$ml$tuning$CrossValidatorParams$_setter_$numFolds_$eq(IntParam x$1) {
        this.numFolds = x$1;
    }

    /** Returns the configured estimator; delegates to the trait's static implementation class. */
    @Override
    public Estimator<?> getEstimator() {
        return CrossValidatorParams$class.getEstimator(this);
    }

    /** Returns the configured ParamMaps; delegates to the trait's static implementation class. */
    @Override
    public ParamMap[] getEstimatorParamMaps() {
        return CrossValidatorParams$class.getEstimatorParamMaps(this);
    }

    /** Returns the configured evaluator; delegates to the trait's static implementation class. */
    @Override
    public Evaluator getEvaluator() {
        return CrossValidatorParams$class.getEvaluator(this);
    }

    /** Returns the configured fold count; delegates to the trait's static implementation class. */
    @Override
    public int getNumFolds() {
        return CrossValidatorParams$class.getNumFolds(this);
    }

    /** Returns this stage's unique identifier. */
    @Override
    public String uid() {
        return this.uid;
    }

    /** Private accessor for the BLAS helper (Scala `private val` compiles to a method). */
    private F2jBLAS f2jBLAS() {
        return this.f2jBLAS;
    }

    /**
     * Sets the estimator to cross-validate.
     *
     * @param value estimator to fit for each candidate ParamMap
     * @return this instance, for call chaining
     */
    public CrossValidator setEstimator(Estimator<?> value) {
        return (CrossValidator)this.set(this.estimator(), value);
    }

    /**
     * Sets the candidate ParamMaps; each one is scored by cross-validation in fit().
     *
     * @param value candidate parameter combinations
     * @return this instance, for call chaining
     */
    public CrossValidator setEstimatorParamMaps(ParamMap[] value) {
        return (CrossValidator)this.set(this.estimatorParamMaps(), value);
    }

    /**
     * Sets the evaluator used to score each model on its validation fold.
     *
     * @param value evaluator; fit() selects the ParamMap with the LARGEST average metric
     * @return this instance, for call chaining
     */
    public CrossValidator setEvaluator(Evaluator value) {
        return (CrossValidator)this.set(this.evaluator(), value);
    }

    /**
     * Sets the number of folds used in fit().
     *
     * @param value fold count; boxed because Param values are stored as objects
     * @return this instance, for call chaining
     */
    public CrossValidator setNumFolds(int value) {
        return (CrossValidator)this.set(this.numFolds(), BoxesRunTime.boxToInteger((int)value));
    }

    /**
     * Runs k-fold cross-validation over {@code estimatorParamMaps} and returns the model
     * refit on the whole dataset with the best-scoring ParamMap.
     *
     * @param dataset input data; its schema is checked via {@code transformSchema(schema, true)}
     *                and reused to rebuild a DataFrame for each fold
     * @return a CrossValidatorModel that wraps the refit best model. This estimator is set as
     *         its parent, and this estimator's param values are copied onto it.
     * @throws MatchError Scala pattern-match residue, thrown if a destructured tuple is null.
     *                    Presumably unreachable in practice -- TODO confirm.
     */
    @Override
    public CrossValidatorModel fit(DataFrame dataset) {
        StructType schema = dataset.schema();
        this.transformSchema(schema, true);
        SQLContext sqlCtx = dataset.sqlContext();
        Estimator<?> est = this.$(this.estimator());
        Evaluator eval = this.$(this.evaluator());
        ParamMap[] epm = this.$(this.estimatorParamMaps());
        int numModels = epm.length;
        // metrics[j] accumulates the sum of validation metrics for epm[j] across all folds.
        double[] metrics = new double[epm.length];
        // The third argument to kFold is the literal 0. It is presumably the random seed, in
        // which case fold assignment is fixed across runs -- TODO confirm.
        Tuple2<RDD<T>, RDD<T>>[] splits = MLUtils$.MODULE$.kFold(dataset.rdd(), BoxesRunTime.unboxToInt((Object)this.$(this.numFolds())), 0, ClassTag$.MODULE$.apply(Row.class));
        // Iterate over ((training, validation), splitIndex) pairs; the anonymous class is the
        // compiled Scala closure body.
        Predef$.MODULE$.refArrayOps((Object[])Predef$.MODULE$.refArrayOps((Object[])splits).zipWithIndex(Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Tuple2.class)))).foreach((Function1)new Serializable(this, schema, sqlCtx, est, eval, epm, numModels, metrics){
            public static final long serialVersionUID = 0L;
            private final /* synthetic */ CrossValidator $outer;
            private final StructType schema$1;
            private final SQLContext sqlCtx$1;
            private final Estimator est$1;
            private final Evaluator eval$1;
            public final ParamMap[] epm$1;
            private final int numModels$1;
            private final double[] metrics$1;

            // Processes one fold. The returned DataFrame is discarded by foreach.
            public final DataFrame apply(Tuple2<Tuple2<RDD<Row>, RDD<Row>>, Object> x0$1) {
                Tuple2<Tuple2<RDD<Row>, RDD<Row>>, Object> tuple2 = x0$1;
                if (tuple2 != null) {
                    Tuple2 tuple22 = (Tuple2)tuple2._1();
                    int splitIndex = tuple2._2$mcI$sp();
                    if (tuple22 != null) {
                        RDD training = (RDD)tuple22._1();
                        RDD validation = (RDD)tuple22._2();
                        // Rebuild DataFrames from the raw Row RDDs using the original schema.
                        // validationDataset is scanned once per model below, hence the cache.
                        // The training cache presumably serves the multi-ParamMap fit.
                        DataFrame trainingDataset = this.sqlCtx$1.createDataFrame(training, this.schema$1).cache();
                        DataFrame validationDataset = this.sqlCtx$1.createDataFrame(validation, this.schema$1).cache();
                        this.$outer.logDebug((Function0<String>)new Serializable(this, splitIndex){
                            public static final long serialVersionUID = 0L;
                            private final int splitIndex$1;

                            public final String apply() {
                                return new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"Train split ", " with multiple sets of parameters."})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToInteger((int)this.splitIndex$1)}));
                            }
                            {
                                this.splitIndex$1 = splitIndex$1;
                            }
                        });
                        // Fit one model per ParamMap in a single call. The loop below assumes
                        // models.apply(j) was trained with epm[j].
                        Seq<M> models = this.est$1.fit(trainingDataset, this.epm$1);
                        // NOTE(review): neither unpersist() is in a finally block, so an
                        // exception from fit/evaluate leaves the fold DataFrames cached.
                        trainingDataset.unpersist();
                        // IntRef is the boxed mutable int that lets the logging closure below
                        // capture the loop counter.
                        IntRef i = new IntRef(0);
                        while (i.elem < this.numModels$1) {
                            double metric = this.eval$1.evaluate(((Transformer)models.apply(i.elem)).transform(validationDataset, this.epm$1[i.elem]));
                            // The closure reads i$1.elem when it runs. It presumably runs
                            // synchronously inside logDebug, before the increment below.
                            this.$outer.logDebug((Function0<String>)new Serializable(this, i, metric){
                                public static final long serialVersionUID = 0L;
                                private final /* synthetic */ $anonfun$fit$1 $outer;
                                private final IntRef i$1;
                                private final double metric$1;

                                public final String apply() {
                                    return new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"Got metric ", " for model trained with ", "."})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToDouble((double)this.metric$1), this.$outer.epm$1[this.i$1.elem]}));
                                }
                                {
                                    if ($outer == null) {
                                        throw new NullPointerException();
                                    }
                                    this.$outer = $outer;
                                    this.i$1 = i$1;
                                    this.metric$1 = metric$1;
                                }
                            });
                            // Post-increment: n is the current model index and the counter
                            // advances. This is equivalent to metrics(i) += metric; i += 1.
                            int n = i.elem++;
                            this.metrics$1[n] = this.metrics$1[n] + metric;
                        }
                        DataFrame dataFrame = validationDataset.unpersist();
                        return dataFrame;
                    }
                }
                throw new MatchError(tuple2);
            }
            {
                if ($outer == null) {
                    throw new NullPointerException();
                }
                this.$outer = $outer;
                this.schema$1 = schema$1;
                this.sqlCtx$1 = sqlCtx$1;
                this.est$1 = est$1;
                this.eval$1 = eval$1;
                this.epm$1 = epm$1;
                this.numModels$1 = numModels$1;
                this.metrics$1 = metrics$1;
            }
        });
        // BLAS dscal scales metrics in place by 1/numFolds, turning the per-ParamMap sums into
        // per-ParamMap averages.
        this.f2jBLAS().dscal(numModels, 1.0 / (double)BoxesRunTime.unboxToInt((Object)this.$(this.numFolds())), metrics, 1);
        this.logInfo((Function0<String>)new Serializable(this, metrics){
            public static final long serialVersionUID = 0L;
            private final double[] metrics$1;

            public final String apply() {
                return new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"Average cross-validation metrics: ", ""})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{Predef$.MODULE$.doubleArrayOps(this.metrics$1).toSeq()}));
            }
            {
                this.metrics$1 = metrics$1;
            }
        });
        // Select the (averageMetric, index) pair with the LARGEST average metric, using the
        // Double ordering.
        // NOTE(review): selection always maximizes. If the evaluator's metric is "smaller is
        // better", this picks the worst ParamMap. Confirm the Evaluator sign convention for
        // this Spark version.
        Tuple2 tuple2 = (Tuple2)Predef$.MODULE$.refArrayOps((Object[])Predef$.MODULE$.doubleArrayOps(metrics).zipWithIndex(Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Tuple2.class)))).maxBy((Function1)new $anonfun$1(this), (Ordering)Ordering.Double$.MODULE$);
        if (tuple2 != null) {
            // Re-boxing into a specialized tuple is a pattern-match artifact:
            // bestMetric2/bestIndex2 equal bestMetric/bestIndex.
            Tuple2.mcDI.sp sp2;
            double bestMetric = tuple2._1$mcD$sp();
            int bestIndex = tuple2._2$mcI$sp();
            Tuple2.mcDI.sp sp3 = sp2 = new Tuple2.mcDI.sp(bestMetric, bestIndex);
            double bestMetric2 = sp3._1$mcD$sp();
            int bestIndex2 = sp3._2$mcI$sp();
            // The parts array below holds a literal backslash-n. StringContext.s presumably
            // expands it to a newline, matching the Scala source's interpolated string.
            this.logInfo((Function0<String>)new Serializable(this, epm, bestIndex2){
                public static final long serialVersionUID = 0L;
                private final ParamMap[] epm$1;
                private final int bestIndex$1;

                public final String apply() {
                    return new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"Best set of parameters:\\n", ""})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{this.epm$1[this.bestIndex$1]}));
                }
                {
                    this.epm$1 = epm$1;
                    this.bestIndex$1 = bestIndex$1;
                }
            });
            this.logInfo((Function0<String>)new Serializable(this, bestMetric2){
                public static final long serialVersionUID = 0L;
                private final double bestMetric$1;

                public final String apply() {
                    return new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"Best cross-validation metric: ", "."})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToDouble((double)this.bestMetric$1)}));
                }
                {
                    this.bestMetric$1 = bestMetric$1;
                }
            });
            // Refit on the FULL dataset (not a fold) with the winning ParamMap.
            Object bestModel = est.fit(dataset, epm[bestIndex2]);
            return this.copyValues(new CrossValidatorModel(this.uid(), (Model<?>)bestModel).setParent(this), this.copyValues$default$2());
        }
        throw new MatchError((Object)tuple2);
    }

    /**
     * Delegates entirely to the configured estimator's transformSchema, so this stage's
     * output schema is the estimator's output schema.
     *
     * @param schema input schema
     * @return the schema produced by the configured estimator
     */
    @Override
    public StructType transformSchema(StructType schema) {
        return ((PipelineStage)this.$(this.estimator())).transformSchema(schema);
    }

    /**
     * Runs the base Params validation. It then validates a copy of the estimator for each
     * candidate ParamMap, presumably so that an invalid combination is reported before any
     * fitting starts.
     */
    @Override
    public void validateParams() {
        Params$class.validateParams(this);
        Estimator<?> est = this.$(this.estimator());
        Predef$.MODULE$.refArrayOps((Object[])this.$(this.estimatorParamMaps())).foreach((Function1)new Serializable(this, est){
            public static final long serialVersionUID = 0L;
            private final Estimator est$2;

            public final void apply(ParamMap paramMap) {
                ((PipelineStage)this.est$2.copy(paramMap)).validateParams();
            }
            {
                this.est$2 = est$2;
            }
        });
    }

    /**
     * Creates a CrossValidator with the given uid.
     * Statement order mirrors Scala initialization: uid first, then trait init (which
     * presumably sets the Param fields), then the BLAS helper.
     *
     * @param uid unique identifier for this stage
     */
    public CrossValidator(String uid) {
        this.uid = uid;
        CrossValidatorParams$class.$init$(this);
        this.f2jBLAS = new F2jBLAS();
    }

    /** Creates a CrossValidator with a random uid carrying the prefix "cv". */
    public CrossValidator() {
        this(Identifiable$.MODULE$.randomUID("cv"));
    }
}

