/*
 * Decompiled with CFR 0.152.
 */
package org.cleartk.ml.feature.selection;

import com.google.common.base.Function;
import com.google.common.base.Joiner;
import com.google.common.base.Predicate;
import com.google.common.collect.Collections2;
import com.google.common.collect.HashBasedTable;
import com.google.common.collect.HashMultiset;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Multiset;
import com.google.common.collect.Ordering;
import com.google.common.collect.Table;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.net.URI;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import org.apache.uima.jcas.JCas;
import org.apache.uima.jcas.tcas.Annotation;
import org.cleartk.ml.Feature;
import org.cleartk.ml.Instance;
import org.cleartk.ml.feature.extractor.CleartkExtractorException;
import org.cleartk.ml.feature.extractor.FeatureExtractor1;
import org.cleartk.ml.feature.selection.FeatureSelectionExtractor;
import org.cleartk.ml.feature.transform.TransformableFeature;

/**
 * Feature extractor that wraps another extractor and restricts its output to features
 * chosen by a mutual-information ranking.
 *
 * <p>Before training, {@code extract} bundles the wrapped extractor's features into a single
 * {@link TransformableFeature} so that {@code train} can later collect feature/outcome
 * counts. After {@code train} or {@code load}, {@code extract} returns only the features
 * whose key (see {@code nameFeature}) is contained in the selected-feature list.
 *
 * @param <OUTCOME_T> type of the classification outcome
 * @param <FOCUS_T> annotation type the wrapped extractor operates on
 */
public class MutualInformationFeatureSelectionExtractor<OUTCOME_T, FOCUS_T extends Annotation>
extends FeatureSelectionExtractor<OUTCOME_T>
implements FeatureExtractor1<FOCUS_T> {
    // Set to true at the end of train() and load(); extract() and save() branch on it.
    protected boolean isTrained;
    // Counts gathered by train(); load() does not populate this field.
    private MutualInformationStats<OUTCOME_T> mutualInfoStats;
    // Wrapped extractor that produces the candidate features.
    private FeatureExtractor1<FOCUS_T> subExtractor;
    // Upper bound on the number of feature names load() reads from a saved model.
    // NOTE(review): within this file train() does not apply this cap - confirm intended.
    private int numFeatures;
    // How per-outcome scores are merged into one score per feature; overwritten by load().
    private CombineScoreMethod combineScoreMethod;
    // Feature keys accepted by apply(): ranked by train(), or read from file by load().
    private List<String> selectedFeatures;
    // Smoothing value handed to MutualInformationStats when training starts.
    private double smoothingCount;

    /**
     * Builds the key under which a feature is counted during training and looked up
     * during selection.
     *
     * @param feature the feature to name
     * @return the feature name alone when the value is a {@link Number}; otherwise
     *         {@code name + ":" + value}
     */
    public String nameFeature(Feature feature) {
        Object value = feature.getValue();
        if (value instanceof Number) {
            // Numeric features share one key regardless of their magnitude.
            return feature.getName();
        }
        return feature.getName() + ":" + value;
    }

    /**
     * Creates an extractor using the {@code MAX} combination method, a smoothing count of
     * {@code 1.0} and a feature limit of {@code 10}.
     *
     * @param name name passed to the superclass constructor
     * @param extractor wrapped extractor that supplies candidate features
     */
    public MutualInformationFeatureSelectionExtractor(String name, FeatureExtractor1<FOCUS_T> extractor) {
        this(name, extractor, CombineScoreMethod.MAX, 1.0, 10);
    }

    /**
     * Creates an extractor using the {@code MAX} combination method and a smoothing count
     * of {@code 1.0}.
     *
     * @param name name passed to the superclass constructor
     * @param extractor wrapped extractor that supplies candidate features
     * @param numFeatures feature limit stored on this instance
     */
    public MutualInformationFeatureSelectionExtractor(String name, FeatureExtractor1<FOCUS_T> extractor, int numFeatures) {
        this(name, extractor, CombineScoreMethod.MAX, 1.0, numFeatures);
    }

    /**
     * Creates a fully configured extractor.
     *
     * @param name name passed to the superclass constructor
     * @param extractor wrapped extractor that supplies candidate features
     * @param combineMeasureType how per-outcome scores are merged into one feature score
     * @param smoothingCount smoothing value forwarded to the statistics at training time
     * @param numFeatures feature limit stored on this instance
     */
    public MutualInformationFeatureSelectionExtractor(String name, FeatureExtractor1<FOCUS_T> extractor, CombineScoreMethod combineMeasureType, double smoothingCount, int numFeatures) {
        super(name);
        this.init(extractor, combineMeasureType, smoothingCount, numFeatures);
    }

    /** Stores the configuration shared by all constructors. */
    private void init(FeatureExtractor1<FOCUS_T> wrapped, CombineScoreMethod scoreMethod, double smoothing, int featureLimit) {
        this.numFeatures = featureLimit;
        this.smoothingCount = smoothing;
        this.combineScoreMethod = scoreMethod;
        this.subExtractor = wrapped;
    }

    /**
     * Runs the wrapped extractor and post-processes its output.
     *
     * <p>Untrained: the extracted features are wrapped in one {@link TransformableFeature}
     * carrying this extractor's name. Trained: the extracted features are filtered using
     * this object as a Guava {@code Predicate}.
     *
     * @param view the CAS view handed to the wrapped extractor
     * @param focusAnnotation the annotation handed to the wrapped extractor
     * @return a new list holding either the single wrapper feature or the retained features
     * @throws CleartkExtractorException if the wrapped extractor fails
     */
    @Override
    public List<Feature> extract(JCas view, FOCUS_T focusAnnotation) throws CleartkExtractorException {
        List<Feature> candidates = this.subExtractor.extract(view, focusAnnotation);
        if (!this.isTrained) {
            ArrayList<Feature> wrapped = new ArrayList<Feature>();
            wrapped.add(new TransformableFeature(this.name, candidates));
            return wrapped;
        }
        ArrayList<Feature> kept = new ArrayList<Feature>();
        // Copy the filtered view so the caller gets an independent list.
        kept.addAll(Collections2.filter(candidates, (Predicate)this));
        return kept;
    }

    /**
     * Collects feature/outcome counts from the given instances and ranks all observed
     * feature keys by descending combined mutual-information score.
     *
     * <p>Only features for which {@code isTransformable} returns true are inspected; each
     * feature nested inside such a wrapper contributes one occurrence for the instance's
     * outcome. The complete ranking is stored; it is not truncated here.
     *
     * @param instances training instances whose features were produced before training
     */
    @Override
    public void train(Iterable<Instance<OUTCOME_T>> instances) {
        MutualInformationStats<OUTCOME_T> stats = new MutualInformationStats<OUTCOME_T>(this.smoothingCount);
        this.mutualInfoStats = stats;

        for (Instance<OUTCOME_T> instance : instances) {
            OUTCOME_T outcome = instance.getOutcome();
            for (Feature feature : instance.getFeatures()) {
                if (this.isTransformable(feature)) {
                    TransformableFeature wrapper = (TransformableFeature) feature;
                    for (Feature inner : wrapper.getFeatures()) {
                        stats.update(this.nameFeature(inner), outcome, 1);
                    }
                }
            }
        }

        Set<String> featureNames = stats.classConditionalCounts.rowKeySet();
        Function<String, Double> scoreOf = stats.getScoreFunction(this.combineScoreMethod);
        // Highest-scoring feature first.
        this.selectedFeatures = Ordering.natural().onResultOf(scoreOf).reverse().immutableSortedCopy(featureNames);
        this.isTrained = true;
    }

    /**
     * Writes the combination method and the ranked feature list with scores to a file.
     *
     * <p>Format: a header line {@code CombineScoreType<TAB>METHOD}, followed by one
     * {@code feature<TAB>score} line per selected feature.
     *
     * @param uri location of the output file
     * @throws IOException if the extractor has not been trained, if no training statistics
     *         are available (for example the model was loaded from file rather than
     *         trained), or if writing fails
     */
    @Override
    public void save(URI uri) throws IOException {
        if (!this.isTrained) {
            throw new IOException("MutualInformationFeatureExtractor: Cannot save before training.");
        }
        // Scores are recomputed from the statistics, which only train() provides. Check
        // before opening the file so an existing model is not truncated for nothing.
        if (this.mutualInfoStats == null) {
            throw new IOException("MutualInformationFeatureExtractor: Cannot save without training statistics.");
        }
        MutualInformationStats.ComputeFeatureScore<OUTCOME_T> computeScore = this.mutualInfoStats.getScoreFunction(this.combineScoreMethod);
        File out = new File(uri);
        BufferedWriter writer = new BufferedWriter(new FileWriter(out));
        try {
            writer.append("CombineScoreType\t");
            writer.append(this.combineScoreMethod.toString());
            writer.append("\n");
            for (String feature : this.selectedFeatures) {
                writer.append(String.format(Locale.ROOT, "%s\t%f\n", feature, computeScore.apply(feature)));
            }
        } finally {
            // Release the file handle even when a write fails.
            writer.close();
        }
    }

    /**
     * Reads a model written by {@code save}: the combination method from the header line
     * and at most {@code numFeatures} feature names from the following lines.
     *
     * <p>The instance is only modified after the file has been read successfully.
     *
     * @param uri location of the model file
     * @throws IOException if the file cannot be read, is empty, or its header line has no
     *         tab-separated method name
     * @throws IllegalArgumentException if the header names an unknown combination method
     */
    @Override
    public void load(URI uri) throws IOException {
        File in = new File(uri);
        List<String> loaded = new ArrayList<String>();
        CombineScoreMethod method;
        BufferedReader reader = new BufferedReader(new FileReader(in));
        try {
            String header = reader.readLine();
            if (header == null) {
                throw new IOException("MutualInformationFeatureExtractor: Model file is empty: " + uri);
            }
            String[] headerFields = header.split("\\t");
            if (headerFields.length < 2) {
                throw new IOException("MutualInformationFeatureExtractor: Malformed header line: " + header);
            }
            method = CombineScoreMethod.valueOf(headerFields[1]);
            String line;
            // Lines are stored best-first, so stopping early keeps the top-ranked features.
            for (int n = 0; n < this.numFeatures && (line = reader.readLine()) != null; ++n) {
                String[] featureValuePair = line.split("\\t");
                loaded.add(featureValuePair[0]);
            }
        } finally {
            // Release the file handle even when parsing fails.
            reader.close();
        }
        this.combineScoreMethod = method;
        this.selectedFeatures = loaded;
        this.isTrained = true;
    }

    /**
     * Tests whether a feature is among the selected features.
     *
     * @param feature the feature to test
     * @return true when the feature's key (see {@code nameFeature}) is in the selected list
     */
    public boolean apply(Feature feature) {
        String key = this.nameFeature(feature);
        return this.selectedFeatures.contains(key);
    }

    /**
     * Returns the list of selected feature keys held by this extractor.
     *
     * @return the internal list itself (not a copy); {@code null} until the extractor has
     *         been trained or loaded
     */
    public final List<String> getSelectedFeatures() {
        return selectedFeatures;
    }

    public static class MutualInformationStats<OUTCOME_T> {
        // Occurrences recorded per outcome; incremented on every update() call.
        protected Multiset<OUTCOME_T> classCounts = HashMultiset.create();
        // Occurrences recorded per (feature name, outcome) pair.
        protected Table<String, OUTCOME_T, Integer> classConditionalCounts = HashBasedTable.create();
        // Value added to each joint-count cell when computing mutual information.
        protected double smoothingCount;

        /**
         * Creates empty statistics.
         *
         * @param smoothingCount value added to every joint count in
         *        {@code mutualInformation}
         */
        public MutualInformationStats(double smoothingCount) {
            // Plain assignment: the previous "+=" only worked because the field starts at 0.
            this.smoothingCount = smoothingCount;
        }

        /**
         * Records that a feature was seen together with an outcome.
         *
         * @param featureName key of the observed feature
         * @param outcome outcome of the instance the feature came from
         * @param occurrences amount added to both the feature/outcome cell and the
         *        outcome total
         */
        public void update(String featureName, OUTCOME_T outcome, int occurrences) {
            Integer previous = this.classConditionalCounts.get(featureName, outcome);
            int base = previous == null ? 0 : previous.intValue();
            this.classConditionalCounts.put(featureName, outcome, base + occurrences);
            this.classCounts.add(outcome, occurrences);
        }

        /**
         * Computes the mutual information between "feature present" and "outcome is
         * {@code outcome}" from the recorded counts.
         *
         * <p>A 2x2 contingency table is built (index 0 = absent/other, 1 = present/this
         * outcome). The smoothing count is added to each joint cell; the marginal counts
         * are used unsmoothed. Cells whose smoothed joint count is zero contribute nothing
         * (the usual {@code 0 * log 0 = 0} convention) instead of turning the result into
         * NaN. The smoothing count is applied as a real number, so fractional values are
         * no longer truncated away.
         *
         * @param featureName key of the feature
         * @param outcome outcome to measure against
         * @return the summed information over the four cells, using the natural logarithm
         */
        public double mutualInformation(String featureName, OUTCOME_T outcome) {
            int[] featureCounts = new int[2];
            int[] outcomeCounts = new int[2];
            int[][] featureOutcomeCounts = new int[2][2];

            // Total number of recorded occurrences across all outcomes.
            int n = this.classCounts.size();
            featureCounts[1] = this.sum(this.classConditionalCounts.row(featureName).values());
            featureCounts[0] = n - featureCounts[1];
            outcomeCounts[1] = this.classCounts.count(outcome);
            outcomeCounts[0] = n - outcomeCounts[1];

            Integer observed = this.classConditionalCounts.get(featureName, outcome);
            featureOutcomeCounts[1][1] = observed == null ? 0 : observed.intValue();
            featureOutcomeCounts[1][0] = featureCounts[1] - featureOutcomeCounts[1][1];
            featureOutcomeCounts[0][1] = outcomeCounts[1] - featureOutcomeCounts[1][1];
            featureOutcomeCounts[0][0] = n - featureCounts[1] - outcomeCounts[1] + featureOutcomeCounts[1][1];

            double information = 0.0;
            for (int nFeature = 0; nFeature <= 1; ++nFeature) {
                for (int nOutcome = 0; nOutcome <= 1; ++nOutcome) {
                    // Keep the smoothed count as a double so fractional smoothing survives.
                    double joint = (double)featureOutcomeCounts[nFeature][nOutcome] + this.smoothingCount;
                    if (joint == 0.0) {
                        // An empty cell carries no information; skipping avoids 0 * -Infinity.
                        continue;
                    }
                    double marginals = (double)featureCounts[nFeature] * (double)outcomeCounts[nOutcome];
                    information += joint / (double)n * Math.log((double)n * joint / marginals);
                }
            }
            return information;
        }

        /**
         * Adds up a collection of integer counts.
         *
         * @param values the counts to add
         * @return the total, or 0 for an empty collection
         */
        private int sum(Collection<Integer> values) {
            int accumulated = 0;
            for (Integer count : values) {
                accumulated += count.intValue();
            }
            return accumulated;
        }

        /**
         * Dumps the statistics to a text file: a title line, a tab-separated header of
         * outcomes, one row of mutual-information values per feature, a blank line, and
         * finally the string form of the raw count table.
         *
         * @param outputURI location of the output file
         * @throws IOException if the file cannot be written
         */
        public void save(URI outputURI) throws IOException {
            File out = new File(outputURI);
            Set<OUTCOME_T> outcomes = this.classConditionalCounts.columnKeySet();
            BufferedWriter writer = new BufferedWriter(new FileWriter(out));
            try {
                writer.append("Mutual Information Data\n");
                writer.append("Feature\t");
                writer.append(Joiner.on("\t").join(outcomes));
                writer.append("\n");
                for (String featureName : this.classConditionalCounts.rowKeySet()) {
                    writer.append(featureName);
                    for (OUTCOME_T outcome : outcomes) {
                        writer.append("\t");
                        writer.append(String.format(Locale.ROOT, "%f", this.mutualInformation(featureName, outcome)));
                    }
                    writer.append("\n");
                }
                writer.append("\n");
                writer.append(this.classConditionalCounts.toString());
            } finally {
                // Release the file handle even when a write fails.
                writer.close();
            }
        }

        /**
         * Creates a function that maps a feature name to its combined score over all
         * outcomes, backed by these statistics.
         *
         * @param combineScoreMethod how the per-outcome scores are merged
         * @return a new scoring function bound to this statistics object
         */
        public ComputeFeatureScore<OUTCOME_T> getScoreFunction(CombineScoreMethod combineScoreMethod) {
            ComputeFeatureScore<OUTCOME_T> scorer = new ComputeFeatureScore<OUTCOME_T>(this, combineScoreMethod);
            return scorer;
        }

        public static class ComputeFeatureScore<OUTCOME_T>
        implements Function<String, Double> {
            private MutualInformationStats<OUTCOME_T> stats;
            private CombineScoreMethod.CombineScoreFunction<OUTCOME_T> combineScoreFunction;

            public ComputeFeatureScore(MutualInformationStats<OUTCOME_T> stats, CombineScoreMethod combineMeasureType) {
                this.stats = stats;
                switch (combineMeasureType) {
                    case AVERAGE: {
                        this.combineScoreFunction = new CombineScoreMethod.AverageScores();
                    }
                    case MAX: {
                        this.combineScoreFunction = new CombineScoreMethod.MaxScores();
                    }
                }
            }

            public Double apply(String featureName) {
                Set outcomes = this.stats.classConditionalCounts.columnKeySet();
                HashMap featureOutcomeMI = Maps.newHashMap();
                for (Object outcome : outcomes) {
                    featureOutcomeMI.put(outcome, this.stats.mutualInformation(featureName, outcome));
                }
                return (Double)this.combineScoreFunction.apply(featureOutcomeMI);
            }
        }
    }

    public static enum CombineScoreMethod {
        AVERAGE,
        MAX;


        public static class MaxScores<OUTCOME_T>
        extends CombineScoreFunction<OUTCOME_T> {
            public Double apply(Map<OUTCOME_T, Double> input) {
                return (Double)Ordering.natural().max(input.values());
            }
        }

        public static class AverageScores<OUTCOME_T>
        extends CombineScoreFunction<OUTCOME_T> {
            public Double apply(Map<OUTCOME_T, Double> input) {
                Collection<Double> scores = input.values();
                int size = scores.size();
                double total = 0.0;
                for (Double score : scores) {
                    total += score.doubleValue();
                }
                return total / (double)size;
            }
        }

        public static abstract class CombineScoreFunction<OUTCOME_T>
        implements Function<Map<OUTCOME_T, Double>, Double> {
        }
    }
}

