/*
 * Decompiled with CFR 0.152.
 */
package edu.usc.irds.agepredictor.authorage;

import edu.usc.irds.agepredictor.spark.authorage.AgePredictModel;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import opennlp.tools.authorage.AgeClassifyME;
import opennlp.tools.authorage.AgeClassifyModel;
import opennlp.tools.util.InvalidFormatException;
import opennlp.tools.util.featuregen.FeatureGenerator;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.ml.feature.CountVectorizerModel;
import org.apache.spark.ml.feature.Normalizer;
import org.apache.spark.ml.linalg.SparseVector;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LassoModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.ArrayType;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

/**
 * Predicts an author's age from free text, running the Spark ML feature pipeline in a local
 * {@link SparkSession}.
 *
 * <p>Two models are loaded from disk:
 * <ul>
 *   <li>an {@link AgeClassifyModel}, whose best category is injected as an extra
 *       {@code cat=<category>} feature;</li>
 *   <li>an {@link AgePredictModel}, which supplies the tokenizer, the feature generators, the
 *       vocabulary and the {@link LassoModel} used for the final regression.</li>
 * </ul>
 *
 * <p>Each instance holds a Spark session. Call {@link #close()}, or use try-with-resources, once
 * the predictor and its session are no longer needed.
 */
public class AgePredicterLocal implements AutoCloseable {
    /**
     * The categorical feature {@code cat=<category>} is added once per this many tokens, so its
     * weight grows with document length. NOTE(review): this presumably has to match the value used
     * when the regression model was trained — confirm before changing.
     */
    private static final int TOKENS_PER_CATEGORY_FEATURE = 18;

    /**
     * Schema of the single-row DataFrame fed to the vectorizer: the raw document and its feature
     * strings.
     */
    private static final StructType SCHEMA = new StructType(new StructField[]{
        new StructField("document", DataTypes.StringType, false, Metadata.empty()),
        new StructField("text", new ArrayType(DataTypes.StringType, true), false, Metadata.empty())
    });

    private final SparkSession spark = SparkSession.builder().master("local").appName("AgePredict").getOrCreate();

    /** Scales each count vector to unit L1 norm ("feature" -> "normFeature"). */
    private final Normalizer normalizer = new Normalizer()
            .setInputCol("feature")
            .setOutputCol("normFeature")
            .setP(1.0);

    private final AgeClassifyModel classifyModel;
    private final AgeClassifyME classify;
    private final AgePredictModel model;

    /**
     * Counts the "text" feature strings against the regression model's vocabulary
     * ("text" -> "feature"). Built once, because the vocabulary does not change after loading.
     */
    private final CountVectorizerModel vectorizer;

    /**
     * Loads the models from {@code ./model/classify-bigram.bin} and
     * {@code ./model/regression-global.bin}, resolved against the current working directory.
     *
     * @throws InvalidFormatException if a model loader rejects a model file's format
     * @throws IOException if a model file cannot be read
     */
    public AgePredicterLocal() throws InvalidFormatException, IOException {
        this("./model/classify-bigram.bin", "./model/regression-global.bin");
    }

    /**
     * Loads the classification and regression models from the given paths.
     *
     * @param pathToClassifyModel path to the serialized {@link AgeClassifyModel}
     * @param pathToRegressionModel path to the serialized {@link AgePredictModel}
     * @throws InvalidFormatException if a model loader rejects a model file's format
     * @throws IOException if a model file cannot be read
     */
    public AgePredicterLocal(String pathToClassifyModel, String pathToRegressionModel) throws InvalidFormatException, IOException {
        this.classifyModel = new AgeClassifyModel(new File(pathToClassifyModel));
        this.classify = new AgeClassifyME(this.classifyModel);
        this.model = AgePredictModel.readModel(new File(pathToRegressionModel));
        this.vectorizer = new CountVectorizerModel(this.model.getVocabulary())
                .setInputCol("text")
                .setOutputCol("feature");
    }

    /**
     * Predicts the age of the author of {@code document}.
     *
     * <p>The steps are:
     * <ol>
     *   <li>Tokenize the document with the regression model's tokenizer.</li>
     *   <li>Extract features with the model's feature generators.</li>
     *   <li>Append the classifier's best category as {@code cat=<category>}, repeated once per
     *       {@value #TOKENS_PER_CATEGORY_FEATURE} tokens.</li>
     *   <li>Count the features against the model vocabulary, L1-normalize them, and score them with
     *       the Lasso model.</li>
     * </ol>
     *
     * @param document the text to analyse
     * @return the value predicted by the regression model
     * @throws InvalidFormatException declared for compatibility with existing callers
     * @throws IOException declared for compatibility with existing callers
     */
    public double predictAge(String document) throws InvalidFormatException, IOException {
        String[] tokens = this.model.getContext().getTokenizer().tokenize(document);
        double[] probabilities = this.classify.getProbabilities(tokens);
        String category = this.classify.getBestCategory(probabilities);

        ArrayList<String> features = new ArrayList<String>();
        for (FeatureGenerator generator : this.model.getContext().getFeatureGenerators()) {
            features.addAll(generator.extractFeatures(tokens));
        }
        if (category != null) {
            features.addAll(Collections.nCopies(tokens.length / TOKENS_PER_CATEGORY_FEATURE, "cat=" + category));
        }

        // Always build the row, even when no features were extracted. Otherwise the DataFrame would
        // be empty and first() below would throw. With no known features the vectorizer yields an
        // all-zero vector, which the L1 normalizer leaves unchanged. The prediction is then the
        // model intercept, the same outcome as a document whose features are all out of vocabulary.
        List<Row> rows = Collections.singletonList(RowFactory.create(document, features.toArray()));
        Dataset<Row> input = this.spark.createDataFrame(rows, SCHEMA);
        Dataset<Row> normalized = this.normalizer.transform(this.vectorizer.transform(input));
        JavaRDD<Row> events = normalized.javaRDD();
        Row event = events.first();

        // The pipeline produces a spark.ml vector; LassoModel expects the spark.mllib vector type.
        SparseVector normFeature = (SparseVector) event.getAs("normFeature");
        Vector testData = Vectors.sparse(normFeature.size(), normFeature.indices(), normFeature.values());
        LassoModel regression = this.model.getModel();
        return regression.predict(testData.compressed());
    }

    /**
     * Stops the Spark session held by this predictor.
     *
     * <p>The session was obtained with {@code getOrCreate()}, so it may be shared with other code
     * in the same JVM. Only close the predictor when nothing else still needs that session.
     */
    @Override
    public void close() {
        this.spark.stop();
    }

    /**
     * Command-line entry point. Predicts an age for the arguments joined with single spaces, or for
     * a built-in sample sentence when no arguments are given, and prints the result.
     *
     * @param args words of the text to analyse
     * @throws Exception if the models cannot be loaded or prediction fails
     */
    public static void main(String[] args) throws Exception {
        String inputText = args.length > 0 ? String.join(" ", args) : "I am very very old person";
        double age;
        try (AgePredicterLocal predictor = new AgePredicterLocal()) {
            age = predictor.predictAge(inputText);
        }
        System.out.println("\n===================\n");
        System.out.println(String.format("Text received- '%s' \n Predicted Age - %f%n", inputText, age));
        System.out.println("\n===================\n");
    }
}

