/*
 * Decompiled with CFR 0.152.
 */
package org.neuroph.core.learning;

import java.io.Serializable;
import java.util.Iterator;
import java.util.Objects;
import org.neuroph.core.Connection;
import org.neuroph.core.Layer;
import org.neuroph.core.Neuron;
import org.neuroph.core.Weight;
import org.neuroph.core.data.DataSet;
import org.neuroph.core.data.DataSetRow;
import org.neuroph.core.learning.IterativeLearning;
import org.neuroph.core.learning.error.ErrorFunction;
import org.neuroph.core.learning.error.MeanSquaredError;
import org.neuroph.core.learning.stop.MaxErrorStop;

/**
 * Base class for iterative learning rules that train a network against a data set of
 * input / desired-output rows.
 *
 * <p>For every row the network is evaluated, and the configured {@link ErrorFunction}
 * computes a pattern error. That error is handed to {@link #updateNetworkWeights(double[])},
 * which subclasses implement.
 *
 * <p>Per epoch, this class records the previous epoch's total error and counts how many
 * consecutive epochs changed the error by no more than {@link #getMinErrorChange()}. When
 * batch mode is enabled, it applies the accumulated weight changes after each epoch.
 */
public abstract class SupervisedLearning extends IterativeLearning implements Serializable {
    private static final long serialVersionUID = 3L;

    /** Total error of the previous epoch; captured in {@link #beforeEpoch()}. Not serialized. */
    protected transient double previousEpochError;

    /**
     * Target error. Only stored here; presumably consumed by the {@link MaxErrorStop}
     * registered in the constructor (verify in that class).
     */
    protected double maxError = 0.01;

    /**
     * Epoch-to-epoch absolute error change at or below which an epoch counts toward
     * {@link #minErrorChangeIterationsCount}. The default of +infinity makes every
     * non-NaN change count.
     */
    private double minErrorChange = Double.POSITIVE_INFINITY;

    /**
     * Limit for consecutive small-change epochs. Only stored here; the consumer of this
     * value is not visible in this class.
     */
    private int minErrorChangeIterationsLimit = Integer.MAX_VALUE;

    /** Number of consecutive epochs whose error change was at most {@link #minErrorChange}. */
    private transient int minErrorChangeIterationsCount;

    /** When true, {@link #afterEpoch()} applies the weight changes accumulated during the epoch. */
    private boolean batchMode = false;

    /** Computes per-pattern errors and accumulates the epoch's total error. Never null. */
    private ErrorFunction errorFunction = new MeanSquaredError();

    /** Creates the rule and registers a {@link MaxErrorStop} stop condition bound to it. */
    public SupervisedLearning() {
        this.stopConditions.add(new MaxErrorStop(this));
    }

    /**
     * Sets the target error, then trains on {@code trainingSet} via {@code learn(DataSet)}.
     *
     * @param trainingSet the rows to train on
     * @param maxError    the target error to store in {@link #maxError}
     */
    public void learn(DataSet trainingSet, double maxError) {
        this.maxError = maxError;
        this.learn(trainingSet);
    }

    /**
     * Sets the target error and the iteration cap, then trains on {@code trainingSet}
     * via {@code learn(DataSet)}.
     *
     * @param trainingSet   the rows to train on
     * @param maxError      the target error to store in {@link #maxError}
     * @param maxIterations the value passed to {@code setMaxIterations}
     */
    public void learn(DataSet trainingSet, double maxError, int maxIterations) {
        this.maxError = maxError;
        this.setMaxIterations(maxIterations);
        this.learn(trainingSet);
    }

    /** Runs the superclass start hook, then clears the small-change counter and previous error. */
    @Override
    protected void onStart() {
        super.onStart();
        this.minErrorChangeIterationsCount = 0;
        this.previousEpochError = 0.0;
    }

    /**
     * Remembers the error function's current total as the previous epoch's error, then
     * resets the error function so the coming epoch accumulates from scratch.
     */
    @Override
    protected void beforeEpoch() {
        this.previousEpochError = this.errorFunction.getTotalError();
        this.errorFunction.reset();
    }

    /**
     * Updates the consecutive small-change counter and, in batch mode, applies the
     * accumulated weight changes.
     *
     * <p>The counter is incremented when |previous - current| is at most
     * {@link #minErrorChange}. Otherwise it is reset to zero; that includes a NaN
     * difference, because a NaN comparison is false.
     */
    @Override
    protected void afterEpoch() {
        double currentEpochError = this.errorFunction.getTotalError();
        double absErrorChange = Math.abs(this.previousEpochError - currentEpochError);
        if (absErrorChange <= this.minErrorChange) {
            this.minErrorChangeIterationsCount++;
        } else {
            this.minErrorChangeIterationsCount = 0;
        }
        if (this.batchMode) {
            this.doBatchWeightsUpdate();
        }
    }

    /**
     * Feeds every row of {@code trainingSet} to {@link #learnPattern(DataSetRow)}.
     *
     * @param trainingSet the rows to learn in this epoch
     */
    @Override
    public void doLearningEpoch(DataSet trainingSet) {
        Iterator<DataSetRow> rows = trainingSet.iterator();
        // Re-check isStopped() before each row so a stop request takes effect mid-epoch.
        while (rows.hasNext() && !this.isStopped()) {
            this.learnPattern(rows.next());
        }
    }

    /**
     * Runs one training row through the network and passes the resulting pattern error
     * to {@link #updateNetworkWeights(double[])}.
     *
     * @param trainingElement the row supplying the network input and the desired output
     */
    protected void learnPattern(DataSetRow trainingElement) {
        this.neuralNetwork.setInput(trainingElement.getInput());
        this.neuralNetwork.calculate();
        double[] output = this.neuralNetwork.getOutput();
        double[] patternError =
                this.errorFunction.calculatePatternError(output, trainingElement.getDesiredOutput());
        this.updateNetworkWeights(patternError);
    }

    /**
     * Applies and clears accumulated weight changes on every input connection of every
     * neuron in layers 1 .. n-1, iterating from the last layer down.
     *
     * <p>Layer 0 is skipped, presumably because it is the input layer (confirm against
     * the network model).
     */
    protected void doBatchWeightsUpdate() {
        Layer[] layers = this.neuralNetwork.getLayers();
        for (int i = this.neuralNetwork.getLayersCount() - 1; i > 0; --i) {
            for (Neuron neuron : layers[i].getNeurons()) {
                for (Connection connection : neuron.getInputConnections()) {
                    Weight weight = connection.getWeight();
                    weight.value += weight.weightChange;
                    weight.weightChange = 0.0;
                }
            }
        }
    }

    /** @return whether weight changes are applied once per epoch rather than per pattern */
    public boolean isInBatchMode() {
        return this.batchMode;
    }

    /** @param batchMode true to apply accumulated weight changes in {@link #afterEpoch()} */
    public void setBatchMode(boolean batchMode) {
        this.batchMode = batchMode;
    }

    /** @param maxError the new target error */
    public void setMaxError(double maxError) {
        this.maxError = maxError;
    }

    /** @return the current target error */
    public double getMaxError() {
        return this.maxError;
    }

    /** @return the total error recorded at the start of the current (or last) epoch */
    public double getPreviousEpochError() {
        return this.previousEpochError;
    }

    /** @return the error-change threshold used to count small-change epochs */
    public double getMinErrorChange() {
        return this.minErrorChange;
    }

    /** @param minErrorChange the error-change threshold used to count small-change epochs */
    public void setMinErrorChange(double minErrorChange) {
        this.minErrorChange = minErrorChange;
    }

    /** @return the configured limit for consecutive small-change epochs */
    public int getMinErrorChangeIterationsLimit() {
        return this.minErrorChangeIterationsLimit;
    }

    /** @param minErrorChangeIterationsLimit the limit for consecutive small-change epochs */
    public void setMinErrorChangeIterationsLimit(int minErrorChangeIterationsLimit) {
        this.minErrorChangeIterationsLimit = minErrorChangeIterationsLimit;
    }

    /** @return how many consecutive epochs so far changed the error by at most the threshold */
    public int getMinErrorChangeIterationsCount() {
        return this.minErrorChangeIterationsCount;
    }

    /** @return the error function in use (never null) */
    public ErrorFunction getErrorFunction() {
        return this.errorFunction;
    }

    /**
     * Replaces the error function used for pattern and total error.
     *
     * @param errorFunction the new error function
     * @throws NullPointerException if {@code errorFunction} is null; this fails fast here
     *         instead of deep inside a training epoch
     */
    public void setErrorFunction(ErrorFunction errorFunction) {
        this.errorFunction = Objects.requireNonNull(errorFunction, "errorFunction");
    }

    /** @return the error function's current total error */
    public double getTotalNetworkError() {
        return this.errorFunction.getTotalError();
    }

    /**
     * Adjusts the network's weights for one training pattern.
     *
     * @param patternError the error vector produced by the error function for the pattern
     */
    protected abstract void updateNetworkWeights(double[] patternError);
}

