/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.ml.classification;

import java.io.IOException;
import java.io.Serializable;
import org.apache.spark.ml.ann.FeedForwardTopology;
import org.apache.spark.ml.ann.FeedForwardTopology$;
import org.apache.spark.ml.ann.FeedForwardTrainer;
import org.apache.spark.ml.ann.TopologyModel;
import org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel;
import org.apache.spark.ml.classification.MultilayerPerceptronClassifier$;
import org.apache.spark.ml.classification.MultilayerPerceptronParams;
import org.apache.spark.ml.classification.ProbabilisticClassifier;
import org.apache.spark.ml.feature.OneHotEncoderModel;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.param.DoubleParam;
import org.apache.spark.ml.param.IntArrayParam;
import org.apache.spark.ml.param.IntParam;
import org.apache.spark.ml.param.LongParam;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.param.shared.HasBlockSize;
import org.apache.spark.ml.param.shared.HasMaxIter;
import org.apache.spark.ml.param.shared.HasSeed;
import org.apache.spark.ml.param.shared.HasSolver;
import org.apache.spark.ml.param.shared.HasStepSize;
import org.apache.spark.ml.param.shared.HasTol;
import org.apache.spark.ml.util.DefaultParamsWritable;
import org.apache.spark.ml.util.Identifiable$;
import org.apache.spark.ml.util.Instrumentation$;
import org.apache.spark.ml.util.MLReader;
import org.apache.spark.ml.util.MLWritable;
import org.apache.spark.ml.util.MLWriter;
import org.apache.spark.mllib.optimization.Optimizer;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.Row$;
import scala.Function1;
import scala.MatchError;
import scala.Predef$;
import scala.Some;
import scala.Tuple2;
import scala.collection.Seq;
import scala.collection.SeqLike;
import scala.collection.mutable.ArrayOps;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxesRunTime;

@ScalaSignature(bytes="\u0006\u0001\u0005]f\u0001\u0002\u000e\u001c\u0001\u0019B\u0001B\u0010\u0001\u0003\u0006\u0004%\te\u0010\u0005\t-\u0002\u0011\t\u0011)A\u0005\u0001\")\u0001\f\u0001C\u00013\")\u0001\f\u0001C\u0001;\")q\f\u0001C\u0001A\")A\u000e\u0001C\u0001[\")\u0001\u000f\u0001C\u0001c\")a\u000f\u0001C\u0001o\")!\u0010\u0001C\u0001w\"9\u00111\u0001\u0001\u0005\u0002\u0005\u0015\u0001bBA\t\u0001\u0011\u0005\u00111\u0003\u0005\b\u00033\u0001A\u0011AA\u000e\u0011\u001d\t\t\u0003\u0001C!\u0003GAq!a\u000e\u0001\t#\nIdB\u0004\u0002fmA\t!a\u001a\u0007\riY\u0002\u0012AA5\u0011\u0019A\u0006\u0003\"\u0001\u0002~!Q\u0011q\u0010\tC\u0002\u0013\u00051$!!\t\u0011\u0005E\u0005\u0003)A\u0005\u0003\u0007C!\"a%\u0011\u0005\u0004%\taGAA\u0011!\t)\n\u0005Q\u0001\n\u0005\r\u0005BCAL!\t\u0007I\u0011A\u000e\u0002\u001a\"A\u0011Q\u0014\t!\u0002\u0013\tY\nC\u0004\u0002 B!\t%!)\t\u0013\u0005%\u0006#!A\u0005\n\u0005-&AH'vYRLG.Y=feB+'oY3qiJ|gn\u00117bgNLg-[3s\u0015\taR$\u0001\bdY\u0006\u001c8/\u001b4jG\u0006$\u0018n\u001c8\u000b\u0005yy\u0012AA7m\u0015\t\u0001\u0013%A\u0003ta\u0006\u00148N\u0003\u0002#G\u00051\u0011\r]1dQ\u0016T\u0011\u0001J\u0001\u0004_J<7\u0001A\n\u0005\u0001\u001d*\u0004\bE\u0003)S-\n$'D\u0001\u001c\u0013\tQ3DA\fQe>\u0014\u0017MY5mSN$\u0018nY\"mCN\u001c\u0018NZ5feB\u0011AfL\u0007\u0002[)\u0011a&H\u0001\u0007Y&t\u0017\r\\4\n\u0005Aj#A\u0002,fGR|'\u000f\u0005\u0002)\u0001A\u0011\u0001fM\u0005\u0003im\u0011q%T;mi&d\u0017-_3s!\u0016\u00148-\u001a9ue>t7\t\\1tg&4\u0017nY1uS>tWj\u001c3fYB\u0011\u0001FN\u0005\u0003om\u0011!$T;mi&d\u0017-_3s!\u0016\u00148-\u001a9ue>t\u0007+\u0019:b[N\u0004\"!\u000f\u001f\u000e\u0003iR!aO\u000f\u0002\tU$\u0018\u000e\\\u0005\u0003{i\u0012Q\u0003R3gCVdG\u000fU1sC6\u001cxK]5uC\ndW-A\u0002vS\u0012,\u0012\u0001\u0011\t\u0003\u0003*s!A\u0011%\u0011\u0005\r3U\"\u0001#\u000b\u0005\u0015+\u0013A\u0002\u001fs_>$hHC\u0001H\u0003\u0015\u00198-\u00197b\u0013\tIe)\u0001\u0004Qe\u0016$WMZ\u0005\u0003\u00172\u0013aa\u0015;sS:<'BA%GQ\r\ta\n\u0016\t\u0003\u001f
Jk\u0011\u0001\u0015\u0006\u0003#~\t!\"\u00198o_R\fG/[8o\u0013\t\u0019\u0006KA\u0003TS:\u001cW-I\u0001V\u0003\u0015\td&\u000e\u00181\u0003\u0011)\u0018\u000e\u001a\u0011)\u0007\tqE+\u0001\u0004=S:LGO\u0010\u000b\u0003ciCQAP\u0002A\u0002\u0001C3A\u0017(UQ\r\u0019a\n\u0016\u000b\u0002c!\u001aAA\u0014+\u0002\u0013M,G\u000fT1zKJ\u001cHCA1c\u001b\u0005\u0001\u0001\"B2\u0006\u0001\u0004!\u0017!\u0002<bYV,\u0007cA3gQ6\ta)\u0003\u0002h\r\n)\u0011I\u001d:bsB\u0011Q-[\u0005\u0003U\u001a\u00131!\u00138uQ\r)a\nV\u0001\rg\u0016$(\t\\8dWNK'0\u001a\u000b\u0003C:DQa\u0019\u0004A\u0002!D3A\u0002(U\u0003%\u0019X\r^*pYZ,'\u000f\u0006\u0002be\")1m\u0002a\u0001\u0001\"\u001aqA\u0014;\"\u0003U\fQA\r\u00181]A\n!b]3u\u001b\u0006D\u0018\n^3s)\t\t\u0007\u0010C\u0003d\u0011\u0001\u0007\u0001\u000eK\u0002\t\u001dR\u000baa]3u)>dGCA1}\u0011\u0015\u0019\u0017\u00021\u0001~!\t)g0\u0003\u0002\u0000\r\n1Ai\\;cY\u0016D3!\u0003(U\u0003\u001d\u0019X\r^*fK\u0012$2!YA\u0004\u0011\u0019\u0019'\u00021\u0001\u0002\nA\u0019Q-a\u0003\n\u0007\u00055aI\u0001\u0003M_:<\u0007f\u0001\u0006O)\u0006\t2/\u001a;J]&$\u0018.\u00197XK&<\u0007\u000e^:\u0015\u0007\u0005\f)\u0002C\u0003d\u0017\u0001\u00071\u0006K\u0002\f\u001dR\f1b]3u'R,\u0007oU5{KR\u0019\u0011-!\b\t\u000b\rd\u0001\u0019A?)\u00071qE/\u0001\u0003d_BLHcA\u0019\u0002&!9\u0011qE\u0007A\u0002\u0005%\u0012!B3yiJ\f\u0007\u0003BA\u0016\u0003ci!!!\f\u000b\u0007\u0005=R$A\u0003qCJ\fW.\u0003\u0003\u00024\u00055\"\u0001\u0003)be\u0006lW*\u00199)\u00075qE+A\u0003ue\u0006Lg\u000eF\u00023\u0003wAq!!\u0010\u000f\u0001\u0004\ty$A\u0004eCR\f7/\u001a;1\t\u0005\u0005\u0013\u0011\u000b\t\u0007\u0003\u0007\nI%!\u0014\u000e\u0005\u0005\u0015#bAA$?\u0005\u00191/\u001d7\n\t\u0005-\u0013Q\t\u0002\b\t\u0006$\u0018m]3u!\u0011\ty%!\u0015\r\u0001\u0011a\u00111KA\u001e\u0003\u0003\u0005\tQ!\u0001\u0002V\t\u0019q\fJ\u0019\u0012\t\u0005]\u0013Q\f\t\u0004K\u0006e\u0013bAA.\r\n9aj\u001c;iS:<\u0007cA3\u0002`%\u0019\u0011\u0011\r$\u0003\u0007\u0005s\u0017\u0010K\u0002\u0001\u001dR\u000ba$T;mi&
d\u0017-_3s!\u0016\u00148-\u001a9ue>t7\t\\1tg&4\u0017.\u001a:\u0011\u0005!\u00022c\u0002\t\u0002l\u0005E\u0014q\u000f\t\u0004K\u00065\u0014bAA8\r\n1\u0011I\\=SK\u001a\u0004B!OA:c%\u0019\u0011Q\u000f\u001e\u0003+\u0011+g-Y;miB\u000b'/Y7t%\u0016\fG-\u00192mKB\u0019Q-!\u001f\n\u0007\u0005mdI\u0001\u0007TKJL\u0017\r\\5{C\ndW\r\u0006\u0002\u0002h\u0005)AJ\u0011$H'V\u0011\u00111\u0011\t\u0005\u0003\u000b\u000by)\u0004\u0002\u0002\b*!\u0011\u0011RAF\u0003\u0011a\u0017M\\4\u000b\u0005\u00055\u0015\u0001\u00026bm\u0006L1aSAD\u0003\u0019a%IR$TA\u0005\u0011q\tR\u0001\u0004\u000f\u0012\u0003\u0013\u0001E:vaB|'\u000f^3e'>dg/\u001a:t+\t\tY\n\u0005\u0003fM\u0006\r\u0015!E:vaB|'\u000f^3e'>dg/\u001a:tA\u0005!An\\1e)\r\t\u00141\u0015\u0005\u0007\u0003KC\u0002\u0019\u0001!\u0002\tA\fG\u000f\u001b\u0015\u000419#\u0018a\u0003:fC\u0012\u0014Vm]8mm\u0016$\"!!,\u0011\t\u0005\u0015\u0015qV\u0005\u0005\u0003c\u000b9I\u0001\u0004PE*,7\r\u001e\u0015\u0004!9#\bfA\bOi\u0002")
/*
 * Estimator that trains a feed-forward multilayer perceptron classifier. This file is CFR output
 * decompiled from Spark ML's Scala source.
 *
 * Configuration is held in Spark ML Params: layers, solver, initialWeights, blockSize, stepSize,
 * tol, maxIter and seed, plus the inherited label/features/prediction columns. train(...) does
 * the following:
 *   - one-hot encodes the label column;
 *   - builds a FeedForwardTopology from `layers`;
 *   - configures a FeedForwardTrainer with the optimizer selected by `solver`;
 *   - returns a MultilayerPerceptronClassificationModel holding the learned weights.
 *
 * NOTE(review): the private final fields below are assigned from the trait-generated
 * `..._setter_$..._$eq` methods rather than from the constructor. This presumably mirrors how
 * scalac initialises trait vals. It is not legal hand-written Java, so treat this file as a
 * faithful view of bytecode rather than compilable source.
 */
public class MultilayerPerceptronClassifier
extends ProbabilisticClassifier<Vector, MultilayerPerceptronClassifier, MultilayerPerceptronClassificationModel>
implements MultilayerPerceptronParams,
DefaultParamsWritable {
    private final String uid;
    private final IntArrayParam layers;
    private final Param<String> solver;
    private final Param<Vector> initialWeights;
    private final IntParam blockSize;
    private final DoubleParam stepSize;
    private final DoubleParam tol;
    private final IntParam maxIter;
    private final LongParam seed;

    /** Loads a previously saved classifier from {@code string}; delegates to the companion object. */
    public static MultilayerPerceptronClassifier load(String string) {
        return MultilayerPerceptronClassifier$.MODULE$.load(string);
    }

    /** Returns the reader for this estimator; delegates to the companion object. */
    public static MLReader<MultilayerPerceptronClassifier> read() {
        return MultilayerPerceptronClassifier$.MODULE$.read();
    }

    /** Returns the writer supplied by {@link DefaultParamsWritable}. */
    @Override
    public MLWriter write() {
        return DefaultParamsWritable.write$(this);
    }

    /** Saves this estimator to {@code path} via {@link MLWritable}'s default implementation. */
    @Override
    public void save(String path) throws IOException {
        MLWritable.save$(this, path);
    }

    // ---- Param getters: each delegates to the trait's default implementation. ----

    @Override
    public final int[] getLayers() {
        return MultilayerPerceptronParams.getLayers$(this);
    }

    @Override
    public final Vector getInitialWeights() {
        return MultilayerPerceptronParams.getInitialWeights$(this);
    }

    @Override
    public final int getBlockSize() {
        return HasBlockSize.getBlockSize$(this);
    }

    @Override
    public final String getSolver() {
        return HasSolver.getSolver$(this);
    }

    @Override
    public final double getStepSize() {
        return HasStepSize.getStepSize$(this);
    }

    @Override
    public final double getTol() {
        return HasTol.getTol$(this);
    }

    @Override
    public final int getMaxIter() {
        return HasMaxIter.getMaxIter$(this);
    }

    @Override
    public final long getSeed() {
        return HasSeed.getSeed$(this);
    }

    // ---- Param accessors and trait-generated field setters. ----

    @Override
    public final IntArrayParam layers() {
        return this.layers;
    }

    @Override
    public final Param<String> solver() {
        return this.solver;
    }

    @Override
    public final Param<Vector> initialWeights() {
        return this.initialWeights;
    }

    @Override
    public final void org$apache$spark$ml$classification$MultilayerPerceptronParams$_setter_$layers_$eq(IntArrayParam x$1) {
        this.layers = x$1;
    }

    @Override
    public final void org$apache$spark$ml$classification$MultilayerPerceptronParams$_setter_$solver_$eq(Param<String> x$1) {
        this.solver = x$1;
    }

    @Override
    public final void org$apache$spark$ml$classification$MultilayerPerceptronParams$_setter_$initialWeights_$eq(Param<Vector> x$1) {
        this.initialWeights = x$1;
    }

    @Override
    public final IntParam blockSize() {
        return this.blockSize;
    }

    @Override
    public final void org$apache$spark$ml$param$shared$HasBlockSize$_setter_$blockSize_$eq(IntParam x$1) {
        this.blockSize = x$1;
    }

    /*
     * Deliberately empty. The `solver` field is assigned by the MultilayerPerceptronParams setter
     * above, presumably because that trait overrides HasSolver's param. Do not add an assignment
     * here without checking trait initialisation order.
     */
    @Override
    public void org$apache$spark$ml$param$shared$HasSolver$_setter_$solver_$eq(Param<String> x$1) {
    }

    @Override
    public DoubleParam stepSize() {
        return this.stepSize;
    }

    @Override
    public void org$apache$spark$ml$param$shared$HasStepSize$_setter_$stepSize_$eq(DoubleParam x$1) {
        this.stepSize = x$1;
    }

    @Override
    public final DoubleParam tol() {
        return this.tol;
    }

    @Override
    public final void org$apache$spark$ml$param$shared$HasTol$_setter_$tol_$eq(DoubleParam x$1) {
        this.tol = x$1;
    }

    @Override
    public final IntParam maxIter() {
        return this.maxIter;
    }

    @Override
    public final void org$apache$spark$ml$param$shared$HasMaxIter$_setter_$maxIter_$eq(IntParam x$1) {
        this.maxIter = x$1;
    }

    @Override
    public final LongParam seed() {
        return this.seed;
    }

    @Override
    public final void org$apache$spark$ml$param$shared$HasSeed$_setter_$seed_$eq(LongParam x$1) {
        this.seed = x$1;
    }

    /** Returns the unique identifier of this estimator instance. */
    @Override
    public String uid() {
        return this.uid;
    }

    // ---- Fluent param setters: each stores the value and returns this estimator. ----

    /** Sets the layer sizes; the first entry is the input size and the last is the output size. */
    public MultilayerPerceptronClassifier setLayers(int[] value) {
        return (MultilayerPerceptronClassifier)this.set(this.layers(), value);
    }

    /** Sets the block size; {@code train} passes it to the trainer as its stack size. */
    public MultilayerPerceptronClassifier setBlockSize(int value) {
        return (MultilayerPerceptronClassifier)this.set(this.blockSize(), BoxesRunTime.boxToInteger((int)value));
    }

    /** Sets the solver name; {@code train} accepts only the LBFGS and GD values of the companion. */
    public MultilayerPerceptronClassifier setSolver(String value) {
        return (MultilayerPerceptronClassifier)this.set(this.solver(), value);
    }

    /** Sets the maximum number of optimizer iterations. */
    public MultilayerPerceptronClassifier setMaxIter(int value) {
        return (MultilayerPerceptronClassifier)this.set(this.maxIter(), BoxesRunTime.boxToInteger((int)value));
    }

    /** Sets the convergence tolerance passed to the optimizer. */
    public MultilayerPerceptronClassifier setTol(double value) {
        return (MultilayerPerceptronClassifier)this.set(this.tol(), BoxesRunTime.boxToDouble((double)value));
    }

    /** Sets the seed; {@code train} uses it only when {@code initialWeights} is not defined. */
    public MultilayerPerceptronClassifier setSeed(long value) {
        return (MultilayerPerceptronClassifier)this.set(this.seed(), BoxesRunTime.boxToLong((long)value));
    }

    /** Sets explicit initial weights; when defined they take precedence over {@code seed}. */
    public MultilayerPerceptronClassifier setInitialWeights(Vector value) {
        return (MultilayerPerceptronClassifier)this.set(this.initialWeights(), value);
    }

    /** Sets the step size; {@code train} uses it only with the GD solver. */
    public MultilayerPerceptronClassifier setStepSize(double value) {
        return (MultilayerPerceptronClassifier)this.set(this.stepSize(), BoxesRunTime.boxToDouble((double)value));
    }

    /** Returns a copy of this estimator with {@code extra} params applied, via {@code defaultCopy}. */
    @Override
    public MultilayerPerceptronClassifier copy(ParamMap extra) {
        return (MultilayerPerceptronClassifier)this.defaultCopy(extra);
    }

    /**
     * Fits a multilayer perceptron on {@code dataset}.
     *
     * <p>Steps:
     * <ol>
     *   <li>One-hot encode the label column (last category not dropped) into the temporary column
     *       {@code "_encoded" + labelCol}. The number of categories is the last entry of
     *       {@code layers}.</li>
     *   <li>Map every (features, encoded label) row into a pair for a {@link FeedForwardTrainer}.</li>
     *   <li>Choose the optimizer from {@code solver}. The LBFGS solver uses {@code tol} and
     *       {@code maxIter}; the GD solver also uses {@code stepSize}.</li>
     *   <li>Pass {@code blockSize} to the trainer as its stack size.</li>
     *   <li>Initialise weights from {@code initialWeights} if it is defined, otherwise from
     *       {@code seed}.</li>
     * </ol>
     *
     * @param dataset input rows; presumably must contain the features and label columns (not
     *     checked here)
     * @return a model built from this estimator's uid and the trained weights
     * @throws IllegalArgumentException if {@code solver} is neither the LBFGS nor the GD value
     * @throws MatchError raised inside the RDD map when a selected row is not a
     *     (Vector, Vector) pair, so it surfaces only when that RDD is evaluated
     */
    @Override
    public MultilayerPerceptronClassificationModel train(Dataset<?> dataset) {
        return (MultilayerPerceptronClassificationModel)Instrumentation$.MODULE$.instrumented((Function1 & Serializable & scala.Serializable)instr -> {
            instr.logPipelineStage(this);
            instr.logDataset(dataset);
            instr.logParams(this, (Seq<Param<?>>)Predef$.MODULE$.wrapRefArray((Object[])new Param[]{this.labelCol(), this.featuresCol(), this.predictionCol(), this.rawPredictionCol(), this.layers(), this.maxIter(), this.tol(), this.blockSize(), this.solver(), this.stepSize(), this.seed()}));
            int[] myLayers = this.$(this.layers());
            // Output layer size doubles as the number of classes. Scala's last() throws on an empty
            // array before the index access below can run.
            int labels = BoxesRunTime.unboxToInt((Object)new ArrayOps.ofInt(Predef$.MODULE$.intArrayOps(myLayers)).last());
            int numFeatures = myLayers[0];
            instr.logNumClasses(labels);
            instr.logNumFeatures(numFeatures);
            String encodedLabelCol = "_encoded" + this.$(this.labelCol());
            // dropLast=false keeps one output per class so the vector matches the output layer width.
            OneHotEncoderModel encodeModel = new OneHotEncoderModel(this.uid(), new int[]{labels}).setInputCols(new String[]{this.$(this.labelCol())}).setOutputCols(new String[]{encodedLabelCol}).setDropLast(false);
            Dataset<Row> encodedDataset = encodeModel.transform(dataset);
            // Decompiled form of `case Row(features: Vector, encodedLabel: Vector) => (features, encodedLabel)`.
            RDD data = encodedDataset.select(this.$(this.featuresCol()), (Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{encodedLabelCol})).rdd().map((Function1 & Serializable & scala.Serializable)x0$1 -> {
                Vector vector;
                Object encodedLabel;
                block3: {
                    Row row;
                    block2: {
                        row = x0$1;
                        Some some = Row$.MODULE$.unapplySeq(row);
                        if (some.isEmpty() || some.get() == null || ((SeqLike)some.get()).lengthCompare(2) != 0) break block2;
                        Object features = ((SeqLike)some.get()).apply(0);
                        encodedLabel = ((SeqLike)some.get()).apply(1);
                        if (!(features instanceof Vector)) break block2;
                        vector = (Vector)features;
                        if (encodedLabel instanceof Vector) break block3;
                    }
                    throw new MatchError((Object)row);
                }
                Vector vector2 = (Vector)encodedLabel;
                Tuple2 tuple2 = new Tuple2((Object)vector, (Object)vector2);
                return tuple2;
            }, ClassTag$.MODULE$.apply(Tuple2.class));
            FeedForwardTopology topology = FeedForwardTopology$.MODULE$.multiLayerPerceptron(myLayers, true);
            FeedForwardTrainer trainer = new FeedForwardTrainer(topology, numFeatures, labels);
            // Explicit initial weights win over the random seed. Both setters configure `trainer` in place.
            if (this.isDefined(this.initialWeights())) {
                trainer.setWeights(this.$(this.initialWeights()));
            } else {
                trainer.setSeed(BoxesRunTime.unboxToLong((Object)this.$(this.seed())));
            }
            String solverName = this.$(this.solver());
            // The optimizer accessors configure the trainer; their return values are not needed.
            if (java.util.Objects.equals(solverName, MultilayerPerceptronClassifier$.MODULE$.LBFGS())) {
                trainer.LBFGSOptimizer().setConvergenceTol(BoxesRunTime.unboxToDouble((Object)this.$(this.tol()))).setNumIterations(BoxesRunTime.unboxToInt((Object)this.$(this.maxIter())));
            } else if (java.util.Objects.equals(solverName, MultilayerPerceptronClassifier$.MODULE$.GD())) {
                trainer.SGDOptimizer().setNumIterations(BoxesRunTime.unboxToInt((Object)this.$(this.maxIter()))).setConvergenceTol(BoxesRunTime.unboxToDouble((Object)this.$(this.tol()))).setStepSize(BoxesRunTime.unboxToDouble((Object)this.$(this.stepSize())));
            } else {
                // Report the configured value, not the Param object (whose string form is its name).
                throw new IllegalArgumentException("The solver " + solverName + " is not supported by MultilayerPerceptronClassifier.");
            }
            trainer.setStackSize(BoxesRunTime.unboxToInt((Object)this.$(this.blockSize())));
            TopologyModel mlpModel = trainer.train((RDD<Tuple2<Vector, Vector>>)data);
            return new MultilayerPerceptronClassificationModel(this.uid(), mlpModel.weights());
        });
    }

    /**
     * Creates an estimator with the given unique identifier.
     *
     * <p>The trait initialisers run in the order emitted by the Scala compiler. They presumably
     * create the Param objects and store them through the {@code _setter_} methods above. Keep
     * this order.
     *
     * @param uid unique identifier for this estimator
     */
    public MultilayerPerceptronClassifier(String uid) {
        this.uid = uid;
        HasSeed.$init$(this);
        HasMaxIter.$init$(this);
        HasTol.$init$(this);
        HasStepSize.$init$(this);
        HasSolver.$init$(this);
        HasBlockSize.$init$(this);
        MultilayerPerceptronParams.$init$(this);
        MLWritable.$init$(this);
        DefaultParamsWritable.$init$(this);
    }

    /** Creates an estimator with a random uid prefixed by {@code "mlpc"}. */
    public MultilayerPerceptronClassifier() {
        this(Identifiable$.MODULE$.randomUID("mlpc"));
    }
}

