/*
 * Decompiled with CFR 0.152.
 */
package org.cleartk.ml.feature.transform.extractor;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.Serializable;
import java.net.URI;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import org.apache.uima.jcas.JCas;
import org.apache.uima.jcas.tcas.Annotation;
import org.cleartk.ml.Feature;
import org.cleartk.ml.Instance;
import org.cleartk.ml.feature.extractor.CleartkExtractorException;
import org.cleartk.ml.feature.extractor.FeatureExtractor1;
import org.cleartk.ml.feature.transform.OneToOneTrainableExtractor_ImplBase;
import org.cleartk.ml.feature.transform.TransformableFeature;

/**
 * Trainable feature extractor that rescales numeric features into the interval [0, 1] using the
 * minimum and maximum value observed for each feature name during training (min-max
 * normalization).
 *
 * <p>Until it is trained or loaded, {@link #extract} wraps the sub-extractor's output in a single
 * {@link TransformableFeature} named after this extractor so that raw values can be collected.
 * Once {@link #train} or {@link #load} has run, each extracted feature is replaced by a feature
 * named {@code "MINMAX_NORMED_" + originalName} whose value is
 * {@code (value - min) / (max - min)}, clamped to [0, 1].
 *
 * @param <OUTCOME_T> outcome type of the training instances
 * @param <FOCUS_T> annotation type handed to the sub-extractor
 */
public class MinMaxNormalizationExtractor<OUTCOME_T, FOCUS_T extends Annotation>
extends OneToOneTrainableExtractor_ImplBase<OUTCOME_T>
implements FeatureExtractor1<FOCUS_T> {
    /** Produces the raw numeric features; may be {@code null} (see the one-argument constructor). */
    private final FeatureExtractor1<FOCUS_T> subExtractor;
    /** Set by {@link #train} and {@link #load}; switches {@link #extract} to normalizing mode. */
    private boolean isTrained;
    /** Learned (min, max) per feature name; {@code null} until trained or loaded. */
    private Map<String, MinMaxPair> minMaxMap;

    /**
     * Creates an extractor without a sub-extractor. Such an instance can be trained, saved,
     * loaded and used to transform features, but {@link #extract} cannot be called on it.
     *
     * @param name name of this extractor (used for the wrapping {@link TransformableFeature})
     */
    public MinMaxNormalizationExtractor(String name) {
        this(name, null);
    }

    /**
     * @param name name of this extractor (used for the wrapping {@link TransformableFeature})
     * @param subExtractor extractor producing the numeric features to normalize
     */
    public MinMaxNormalizationExtractor(String name, FeatureExtractor1<FOCUS_T> subExtractor) {
        super(name);
        this.subExtractor = subExtractor;
        this.isTrained = false;
    }

    /**
     * Min-max normalizes a single numeric feature using the statistics learned for its name.
     *
     * <p>A feature whose name was never seen during training maps to 0.5. If the learned minimum
     * equals the maximum, the result is 0.5 for exactly that value, 0.0 below it and 1.0 above it.
     * The result is always clamped to [0, 1].
     *
     * @param feature feature whose value is a {@link Number}
     * @return a new feature named {@code "MINMAX_NORMED_" + feature.getName()}
     * @throws IllegalStateException if the extractor has been neither trained nor loaded
     * @throws ClassCastException if the feature value is not a {@link Number}
     */
    @Override
    protected Feature transform(Feature feature) {
        String featureName = feature.getName();
        MinMaxPair stats = this.trainedStats().get(featureName);
        double value = ((Number) feature.getValue()).doubleValue();
        double normalized = 0.5;
        if (stats != null) {
            if (stats.min < stats.max) {
                normalized = (value - stats.min) / (stats.max - stats.min);
            } else if (stats.min == stats.max && value != stats.min) {
                // Degenerate range: only one value was ever observed, so anything else is
                // placed at the bottom or top of the interval.
                normalized = value < stats.min ? 0.0 : 1.0;
            }
        }
        // Values outside the training range are clamped into [0, 1].
        normalized = Math.max(0.0, Math.min(1.0, normalized));
        return new Feature("MINMAX_NORMED_" + featureName, normalized);
    }

    /**
     * Runs the sub-extractor on the focus annotation.
     *
     * <p>Before training, the raw features are wrapped in one {@link TransformableFeature} named
     * after this extractor. After training or loading, each feature is normalized with
     * {@link #transform(Feature)}.
     *
     * @throws IllegalStateException if this extractor was created without a sub-extractor
     */
    @Override
    public List<Feature> extract(JCas view, FOCUS_T focusAnnotation) throws CleartkExtractorException {
        if (this.subExtractor == null) {
            throw new IllegalStateException("MinMaxNormalizationExtractor '" + this.name
                    + "' was created without a sub-extractor and cannot extract features");
        }
        List<Feature> extracted = this.subExtractor.extract(view, focusAnnotation);
        List<Feature> result = new ArrayList<Feature>();
        if (this.isTrained) {
            for (Feature feature : extracted) {
                result.add(this.transform(feature));
            }
        } else {
            result.add(new TransformableFeature(this.name, extracted));
        }
        return result;
    }

    /**
     * Learns the minimum and maximum value of every feature wrapped in a transformable feature
     * belonging to this extractor, then switches the extractor into normalizing mode.
     *
     * @param instances training instances produced while the extractor was untrained
     * @throws IllegalArgumentException if a wrapped feature value is not a {@link Number}
     */
    @Override
    public void train(Iterable<Instance<OUTCOME_T>> instances) {
        Map<String, MinMaxRunningStat> featureStatsMap = new HashMap<String, MinMaxRunningStat>();
        for (Instance<OUTCOME_T> instance : instances) {
            for (Feature feature : instance.getFeatures()) {
                if (!this.isTransformable(feature)) {
                    continue;
                }
                for (Feature untransformedFeature : ((TransformableFeature) feature).getFeatures()) {
                    Object featureValue = untransformedFeature.getValue();
                    if (!(featureValue instanceof Number)) {
                        throw new IllegalArgumentException("Cannot normalize non-numeric feature values");
                    }
                    String featureName = untransformedFeature.getName();
                    MinMaxRunningStat stats = featureStatsMap.get(featureName);
                    if (stats == null) {
                        stats = new MinMaxRunningStat();
                        featureStatsMap.put(featureName, stats);
                    }
                    stats.add(((Number) featureValue).doubleValue());
                }
            }
        }
        Map<String, MinMaxPair> learned = new HashMap<String, MinMaxPair>();
        for (Map.Entry<String, MinMaxRunningStat> entry : featureStatsMap.entrySet()) {
            MinMaxRunningStat stats = entry.getValue();
            learned.put(entry.getKey(), new MinMaxPair(stats.min(), stats.max()));
        }
        this.minMaxMap = learned;
        this.isTrained = true;
    }

    /**
     * Writes the learned statistics as one {@code name<TAB>min<TAB>max} line per feature.
     *
     * @param dataUri file URI to write to
     * @throws IOException if the file cannot be written
     * @throws IllegalStateException if the extractor has been neither trained nor loaded
     */
    @Override
    public void save(URI dataUri) throws IOException {
        Map<String, MinMaxPair> stats = this.trainedStats();
        File out = new File(dataUri);
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(out))) {
            for (Map.Entry<String, MinMaxPair> entry : stats.entrySet()) {
                MinMaxPair pair = entry.getValue();
                // Double.toString is locale-independent and round-trips exactly through
                // Double.parseDouble in load(); a fixed "%f" format keeps only six decimals and
                // would collapse small-magnitude ranges to 0.
                writer.append(entry.getKey()).append('\t')
                      .append(Double.toString(pair.min)).append('\t')
                      .append(Double.toString(pair.max)).append('\n');
            }
        }
    }

    /**
     * Reads statistics previously written by {@link #save} and switches the extractor into
     * normalizing mode. Blank lines are ignored; the current statistics are only replaced if the
     * whole file is read successfully.
     *
     * @param dataUri file URI to read from
     * @throws IOException if the file cannot be read or a line lacks the three tab-separated fields
     * @throws NumberFormatException if a min or max field is not a valid number
     */
    @Override
    public void load(URI dataUri) throws IOException {
        File in = new File(dataUri);
        Map<String, MinMaxPair> loaded = new HashMap<String, MinMaxPair>();
        try (BufferedReader reader = new BufferedReader(new FileReader(in))) {
            String line;
            int lineNumber = 0;
            while ((line = reader.readLine()) != null) {
                ++lineNumber;
                if (line.isEmpty()) {
                    continue;
                }
                String[] nameMinMax = line.split("\t");
                if (nameMinMax.length < 3) {
                    throw new IOException(String.format(Locale.ROOT,
                            "Malformed line %d in %s: expected <name>\\t<min>\\t<max> but got \"%s\"",
                            lineNumber, in, line));
                }
                loaded.put(nameMinMax[0], new MinMaxPair(
                        Double.parseDouble(nameMinMax[1]), Double.parseDouble(nameMinMax[2])));
            }
        }
        this.minMaxMap = loaded;
        this.isTrained = true;
    }

    /**
     * Returns the learned statistics, failing clearly if there are none yet.
     */
    private Map<String, MinMaxPair> trainedStats() {
        if (this.minMaxMap == null) {
            throw new IllegalStateException("MinMaxNormalizationExtractor '" + this.name
                    + "' has been neither trained nor loaded");
        }
        return this.minMaxMap;
    }

    /**
     * Running minimum / maximum accumulator over a stream of doubles. When no value has been
     * added, {@link #min()} is {@code +Infinity} and {@link #max()} is {@code -Infinity}.
     */
    public static class MinMaxRunningStat
    implements Serializable {
        private static final long serialVersionUID = 1L;
        private double min;
        private double max;
        private int n;

        public MinMaxRunningStat() {
            this.clear();
        }

        /** Records one sample, updating the running minimum and maximum. */
        public void add(double x) {
            ++this.n;
            if (x < this.min) {
                this.min = x;
            }
            if (x > this.max) {
                this.max = x;
            }
        }

        /** Resets to the empty state so the next added value becomes both min and max. */
        public void clear() {
            this.n = 0;
            // Infinite sentinels: Double.MIN_VALUE is the smallest *positive* double, so using
            // it as the initial maximum would report ~4.9e-324 for all-negative samples.
            this.min = Double.POSITIVE_INFINITY;
            this.max = Double.NEGATIVE_INFINITY;
        }

        /** @return number of samples added since construction or the last {@link #clear()} */
        public int getNumSamples() {
            return this.n;
        }

        public double min() {
            return this.min;
        }

        public double max() {
            return this.max;
        }
    }

    /** Immutable learned (min, max) range for a single feature name. */
    private static class MinMaxPair {
        public final double min;
        public final double max;

        public MinMaxPair(double min, double max) {
            this.min = min;
            this.max = max;
        }
    }
}

