/*
 * Decompiled with CFR 0.152.
 */
package de.jstacs.classifiers.neuralNetworks;

import de.jstacs.NotTrainedException;
import de.jstacs.algorithms.optimization.termination.TerminationCondition;
import de.jstacs.classifiers.AbstractScoreBasedClassifier;
import de.jstacs.classifiers.neuralNetworks.activationFunctions.ActivationFunction;
import de.jstacs.classifiers.neuralNetworks.neurons.InnerNeuron;
import de.jstacs.classifiers.neuralNetworks.neurons.InputNeuron;
import de.jstacs.classifiers.neuralNetworks.neurons.MSEOutputNeuron;
import de.jstacs.classifiers.neuralNetworks.neurons.Neuron;
import de.jstacs.classifiers.neuralNetworks.stepSizeAdaption.StepSizeAdaption;
import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.DataSet;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.io.NonParsableException;
import de.jstacs.io.XMLParser;
import de.jstacs.results.CategoricalResult;
import de.jstacs.results.NumericalResultSet;
import de.jstacs.utils.REnvironment;
import de.jstacs.utils.Time;
import java.util.Random;

public class NeuralNetworkClassifier
extends AbstractScoreBasedClassifier {
    private static Random r = new Random();
    private Neuron[][] neurons;
    private boolean isInitialized;
    private int[][] freeIndexes;
    private int lastIndex;
    private Learning learning;
    private StepSizeAdaption stepSize;
    private ActivationFunction activationFunction;
    private double initStepSize;
    private TerminationCondition termCond;

    public NeuralNetworkClassifier(AlphabetContainer abc, int[] numNeurons, ActivationFunction activationFunction, Learning learning, StepSizeAdaption stepSize, double initStepSize, TerminationCondition termCond) {
        super(abc, numNeurons[numNeurons.length - 1]);
        this.neurons = new Neuron[numNeurons.length][];
        int i = 0;
        while (i < numNeurons.length) {
            this.neurons[i] = new Neuron[numNeurons[i]];
            int j = 0;
            while (j < numNeurons[i]) {
                this.neurons[i][j] = i == 0 ? new InputNeuron(j) : (i == numNeurons.length - 1 ? new MSEOutputNeuron(activationFunction, j, this.neurons[i - 1]) : new InnerNeuron(activationFunction, j, this.neurons[i - 1]));
                ++j;
            }
            ++i;
        }
        this.activationFunction = activationFunction;
        this.learning = learning;
        this.stepSize = stepSize;
        this.initStepSize = initStepSize;
        this.termCond = termCond;
        this.isInitialized = false;
    }

    public NeuralNetworkClassifier(StringBuffer xml) throws NonParsableException {
        super(xml);
    }

    @Override
    protected double getScore(Sequence seq, int i, boolean check) throws IllegalArgumentException, NotTrainedException, Exception {
        return this.neurons[this.neurons.length - 1][i].getOutput(seq);
    }

    @Override
    public String getInstanceName() {
        return "Neural network classifier";
    }

    @Override
    public CategoricalResult[] getClassifierAnnotation() {
        return null;
    }

    @Override
    public NumericalResultSet getNumericalCharacteristics() throws Exception {
        return null;
    }

    @Override
    public boolean isInitialized() {
        return this.isInitialized;
    }

    protected int[][] getIndexes() {
        if (this.learning == Learning.BATCH) {
            return this.freeIndexes;
        }
        int d = r.nextInt(this.lastIndex);
        int[][] temp = new int[][]{this.freeIndexes[d]};
        --this.lastIndex;
        this.freeIndexes[d] = this.freeIndexes[this.lastIndex];
        this.freeIndexes[this.lastIndex] = temp[0];
        if (this.lastIndex == 0) {
            this.lastIndex = this.freeIndexes.length;
        }
        return temp;
    }

    protected void prepareIndexes(DataSet[] data) {
        int num = 0;
        int i = 0;
        while (i < data.length) {
            num += data[i].getNumberOfElements();
            ++i;
        }
        this.freeIndexes = new int[num][2];
        i = 0;
        int j = 0;
        while (i < data.length) {
            int k = 0;
            while (k < data[i].getNumberOfElements()) {
                this.freeIndexes[j][0] = i;
                this.freeIndexes[j][1] = k++;
                ++j;
            }
            ++i;
        }
        this.lastIndex = this.freeIndexes.length;
    }

    @Override
    public void train(DataSet[] data, double[][] weights) throws Exception {
        double totalError;
        int i = 0;
        while (i < this.neurons.length) {
            int j = 0;
            while (j < this.neurons[i].length) {
                this.neurons[i][j].initializeRandomly();
                ++j;
            }
            ++i;
        }
        this.isInitialized = true;
        double[][] desiredOutputs = new double[data.length][data.length];
        int i2 = 0;
        while (i2 < desiredOutputs.length) {
            int j = 0;
            while (j < desiredOutputs[i2].length) {
                desiredOutputs[i2][j] = i2 == j ? this.activationFunction.getPositiveValue() : this.activationFunction.getNegativeValue();
                ++j;
            }
            ++i2;
        }
        this.prepareIndexes(data);
        int epoch = 0;
        int iteration = 1;
        int num = 0;
        double lastTotalError = totalError = Double.POSITIVE_INFINITY;
        double currStepSize = this.initStepSize;
        Time time = Time.getTimeInstance(null);
        REnvironment re = new REnvironment();
        double[][] matrix = new double[data[0].getNumberOfElements() + data[1].getNumberOfElements()][2];
        int i3 = 0;
        int k = 0;
        while (i3 < data.length) {
            int j = 0;
            while (j < data[i3].getNumberOfElements()) {
                matrix[k][0] = data[i3].getElementAt(j).continuousVal(0);
                matrix[k][1] = data[i3].getElementAt(j).continuousVal(1);
                ++j;
                ++k;
            }
            ++i3;
        }
        re.createMatrix("data", matrix);
        re.voidEval("pdf(\"/Users/dev/Desktop/old/Lehre/Muster_WS12/Uebung11/plots.pdf\");");
        double[] outs = new double[matrix.length];
        do {
            if (epoch % 100 == 0) {
                System.out.println(String.valueOf(epoch) + "\t" + totalError + "\t" + currStepSize + "\t" + time.getElapsedTime());
                re.createVector("outs", outs);
                if (epoch > 0) {
                    re.voidEval("outs<-(outs-min(outs))/(max(outs)-min(outs));");
                }
                re.voidEval("plot(data[,1],data[,2],col=rgb(outs,0,0));");
            }
            num = this.freeIndexes.length;
            ++epoch;
            lastTotalError = totalError;
            totalError = 0.0;
            while (num > 0) {
                int[][] idxs = this.getIndexes();
                int n = 0;
                while (n < idxs.length) {
                    int i4 = 0;
                    while (i4 < this.neurons.length) {
                        int j = 0;
                        while (j < this.neurons[i4].length) {
                            this.neurons[i4][j].reset();
                            ++j;
                        }
                        ++i4;
                    }
                    Sequence input = data[idxs[n][0]].getElementAt(idxs[n][1]);
                    double weight = weights == null || weights[idxs[n][0]] == null ? 1.0 : weights[idxs[n][0]][idxs[n][1]];
                    int i5 = 0;
                    while (i5 < this.neurons[this.neurons.length - 1].length) {
                        double out = this.neurons[this.neurons.length - 1][i5].getOutput(input);
                        outs[(idxs[n][0] == 0 ? 0 : data[0].getNumberOfElements()) + idxs[n][1]] = out;
                        totalError += 0.5 * weight * (desiredOutputs[idxs[n][0]][i5] - out) * (desiredOutputs[idxs[n][0]][i5] - out);
                        ++i5;
                    }
                    i5 = 0;
                    while (i5 < this.neurons[0].length) {
                        this.neurons[0][i5].getError(input, weight, desiredOutputs[idxs[n][0]]);
                        ++i5;
                    }
                    ++n;
                }
                int i6 = 0;
                while (i6 < this.neurons.length) {
                    int j = 0;
                    while (j < this.neurons[i6].length) {
                        this.neurons[i6][j].adaptWeights(currStepSize);
                        ++j;
                    }
                    ++i6;
                }
                num -= idxs.length;
                currStepSize = this.stepSize.getStepSize(this.initStepSize, currStepSize, ++iteration, epoch);
            }
        } while (this.termCond.doNextIteration(epoch, lastTotalError, totalError, null, null, currStepSize, time));
        re.voidEval("dev.off()");
        System.out.println(String.valueOf(epoch) + "\t" + totalError + "\t" + currStepSize + "\t" + time.getElapsedTime());
        this.freeIndexes = null;
    }

    @Override
    protected String getXMLTag() {
        return this.getClass().getSimpleName();
    }

    @Override
    protected StringBuffer getFurtherClassifierInfos() {
        StringBuffer xml = new StringBuffer();
        XMLParser.appendObjectWithTags(xml, this.activationFunction, "activationFunction");
        XMLParser.appendObjectWithTags(xml, this.initStepSize, "initStepSize");
        XMLParser.appendObjectWithTags(xml, this.isInitialized, "isInitialized");
        XMLParser.appendObjectWithTags(xml, (Object)this.learning, "learning");
        XMLParser.appendObjectWithTags(xml, this.neurons, "neurons");
        XMLParser.appendObjectWithTags(xml, this.stepSize, "stepSize");
        XMLParser.appendObjectWithTags(xml, this.termCond, "termCond");
        return xml;
    }

    @Override
    protected void extractFurtherClassifierInfosFromXML(StringBuffer xml) throws NonParsableException {
        this.activationFunction = XMLParser.extractObjectForTags(xml, "activationFunction", ActivationFunction.class);
        this.initStepSize = XMLParser.extractObjectForTags(xml, "initStepSize", Double.TYPE);
        this.isInitialized = XMLParser.extractObjectForTags(xml, "isInitialized", Boolean.TYPE);
        this.learning = XMLParser.extractObjectForTags(xml, "learning", Learning.class);
        this.neurons = XMLParser.extractObjectForTags(xml, "neurons", Neuron[][].class);
        this.stepSize = XMLParser.extractObjectForTags(xml, "stepSize", StepSizeAdaption.class);
        this.termCond = XMLParser.extractObjectForTags(xml, "termCond", TerminationCondition.class);
        int i = 1;
        while (i < this.neurons.length) {
            int j = 0;
            while (j < this.neurons[i].length) {
                ((InnerNeuron)this.neurons[i][j]).setPredecessors(this.neurons[i - 1]);
                ++j;
            }
            ++i;
        }
    }

    public static enum Learning {
        BATCH,
        STOCHASTIC;

    }
}

