/*
 * Decompiled with CFR 0.152.
 */
package de.jstacs.classifiers.assessment;

import de.jstacs.classifiers.AbstractClassifier;
import de.jstacs.classifiers.ClassDimensionException;
import de.jstacs.classifiers.assessment.ClassifierAssessment;
import de.jstacs.classifiers.assessment.ClassifierAssessmentAssessParameterSet;
import de.jstacs.classifiers.assessment.KFoldCrossValidationAssessParameterSet;
import de.jstacs.classifiers.performanceMeasures.NumericalPerformanceMeasureParameterSet;
import de.jstacs.data.DataSet;
import de.jstacs.data.EmptyDataSetException;
import de.jstacs.data.WrongAlphabetException;
import de.jstacs.results.CategoricalResult;
import de.jstacs.results.ListResult;
import de.jstacs.results.MeanResultSet;
import de.jstacs.results.Result;
import de.jstacs.results.ResultSet;
import de.jstacs.sequenceScores.statisticalModels.trainable.TrainableStatisticalModel;
import de.jstacs.utils.ProgressUpdater;
import java.util.Arrays;
import java.util.LinkedList;

/**
 * Classifier assessment that runs a k-fold cross-validation.
 *
 * <p>The data of every class is available as {@code k} parts. In each of the
 * {@code k} iterations, part {@code i} of every class is used for testing and
 * the remaining parts of that class are combined via
 * {@link DataSet#union(DataSet[], boolean[])} for training.
 *
 * <p>The parts are either created by partitioning the given data (see
 * {@link #evaluateClassifier}) or supplied by the caller (see
 * {@link #assessWithPredefinedSplits}).
 */
public class KFoldCrossValidation
extends ClassifierAssessment {
    /**
     * Creates a new instance; all arguments are forwarded unchanged to the
     * corresponding {@link ClassifierAssessment} constructor.
     *
     * @param aCs the classifiers to be assessed
     * @param aMs the models handed to the super class
     * @param buildClassifiersByCrossProduct forwarded to the super class
     * @param checkAlphabetConsistencyAndLength forwarded to the super class
     */
    protected KFoldCrossValidation(AbstractClassifier[] aCs, TrainableStatisticalModel[][] aMs, boolean buildClassifiersByCrossProduct, boolean checkAlphabetConsistencyAndLength) throws IllegalArgumentException, WrongAlphabetException, CloneNotSupportedException, ClassDimensionException {
        super(aCs, aMs, buildClassifiersByCrossProduct, checkAlphabetConsistencyAndLength);
    }

    /**
     * Creates a new instance for the given classifiers; the arguments are
     * forwarded unchanged to the corresponding {@link ClassifierAssessment}
     * constructor.
     *
     * @param aCs the classifiers to be assessed
     */
    public KFoldCrossValidation(AbstractClassifier ... aCs) throws IllegalArgumentException, WrongAlphabetException, CloneNotSupportedException, ClassDimensionException {
        super(aCs);
    }

    /**
     * Creates a new instance from the given models; the arguments are
     * forwarded unchanged to the corresponding {@link ClassifierAssessment}
     * constructor.
     *
     * @param buildClassifiersByCrossProduct forwarded to the super class
     * @param aMs the models handed to the super class
     */
    public KFoldCrossValidation(boolean buildClassifiersByCrossProduct, TrainableStatisticalModel[] ... aMs) throws IllegalArgumentException, WrongAlphabetException, CloneNotSupportedException, ClassDimensionException {
        super(buildClassifiersByCrossProduct, aMs);
    }

    /**
     * Creates a new instance from the given classifiers and models; the
     * arguments are forwarded unchanged to the corresponding
     * {@link ClassifierAssessment} constructor.
     *
     * @param aCs the classifiers to be assessed
     * @param buildClassifiersByCrossProduct forwarded to the super class
     * @param aMs the models handed to the super class
     */
    public KFoldCrossValidation(AbstractClassifier[] aCs, boolean buildClassifiersByCrossProduct, TrainableStatisticalModel[] ... aMs) throws IllegalArgumentException, WrongAlphabetException, CloneNotSupportedException, ClassDimensionException {
        super(aCs, buildClassifiersByCrossProduct, aMs);
    }

    /**
     * Partitions the data of every class into {@code k} parts, using the
     * {@code k} and the partition method stored in the given parameter set,
     * and runs the cross-validation on these parts.
     *
     * @param mp the performance measures to be computed
     * @param assessPS the assessment parameters; must be a
     *            {@link KFoldCrossValidationAssessParameterSet}
     * @param s the data, one {@link DataSet} per class
     * @param pU receives the progress of the cross-validation
     * @throws IllegalArgumentException if {@code assessPS} is {@code null} or
     *             not a {@link KFoldCrossValidationAssessParameterSet}, or if
     *             partitioning fails with an {@link EmptyDataSetException}
     *             (which is attached as the cause)
     * @throws Exception if partitioning, training or testing fails otherwise
     */
    @Override
    protected void evaluateClassifier(NumericalPerformanceMeasureParameterSet mp, ClassifierAssessmentAssessParameterSet assessPS, DataSet[] s, ProgressUpdater pU) throws IllegalArgumentException, Exception {
        // Check the type explicitly instead of provoking a ClassCastException.
        if (!(assessPS instanceof KFoldCrossValidationAssessParameterSet)) {
            throw new IllegalArgumentException("Given AssessParameterSet is not of type KFoldCVAssessParameterSet.");
        }
        KFoldCrossValidationAssessParameterSet kFoldPS = (KFoldCrossValidationAssessParameterSet)assessPS;
        DataSet.PartitionMethod splitMethod = kFoldPS.getDataSplitMethod();
        int k = kFoldPS.getK();
        DataSet[][] sInParts = new DataSet[s.length][];
        try {
            for (int i = 0; i < s.length; ++i) {
                sInParts[i] = s[i].partition(k, splitMethod);
            }
        }
        catch (EmptyDataSetException e) {
            // Keep the original exception as cause so the failing split can be diagnosed.
            throw new IllegalArgumentException("Given DataSet s seems to contain to few elements for a " + k + "-fold crossvalidation since at least one empty subset occured during splitting given data into " + k + " non-overlapping parts.", e);
        }
        this.evaluate(mp, kFoldPS, pU, sInParts);
    }

    /**
     * Runs the cross-validation on data that is already split into parts.
     *
     * <p>{@code splitData[c][p]} is part {@code p} of class {@code c}. The
     * number of parts of class 0 defines {@code k}; every other class must
     * provide the same number of parts.
     *
     * @param mp the performance measures to be computed
     * @param caaps supplies the element length used for the test data and the
     *            flag whether a non-computable measure raises an exception
     * @param pU receives {@code k} as maximum and the number of finished
     *            iterations as value
     * @param splitData the parts of every class
     * @throws IllegalArgumentException if no class is given or the classes do
     *             not all have the same number of parts
     * @throws Exception if training or testing fails
     */
    private void evaluate(NumericalPerformanceMeasureParameterSet mp, ClassifierAssessmentAssessParameterSet caaps, ProgressUpdater pU, DataSet[] ... splitData) throws Exception {
        int subSeqL = caaps.getElementLength();
        boolean exceptionIfMPNotComputable = caaps.getExceptionIfMPNotComputable();
        int clazz = splitData.length;
        if (clazz == 0) {
            // Without this guard splitData[0] below fails with an ArrayIndexOutOfBoundsException.
            throw new IllegalArgumentException("No data given: at least one class with predefined splits is required.");
        }
        int k = splitData[0].length;
        // Report the first class whose number of parts differs from class 0.
        for (int j = 1; j < clazz; ++j) {
            if (splitData[j].length != k) {
                throw new IllegalArgumentException("Please check the number of predefined splits per class. Compare class 0 with class " + j);
            }
        }
        // Index 0 holds the training data, index 1 the test data, each with one entry per class.
        DataSet[][] sTrainTestClassWise = new DataSet[2][clazz];
        // Flags handed to DataSet.union: all parts are flagged except the one held out for testing.
        boolean[] useForTraining = new boolean[k];
        Arrays.fill(useForTraining, true);
        pU.setMax(k);
        for (int i = 0; i < k; ++i) {
            useForTraining[i] = false;
            for (int j = 0; j < clazz; ++j) {
                sTrainTestClassWise[0][j] = DataSet.union(splitData[j], useForTraining);
                sTrainTestClassWise[1][j] = new DataSet(splitData[j][i], subSeqL);
            }
            // Restore the flag so the next iteration holds out exactly one part again.
            useForTraining[i] = true;
            this.train(sTrainTestClassWise[0]);
            this.test(mp, exceptionIfMPNotComputable, sTrainTestClassWise[1]);
            pU.setValue(i + 1);
        }
    }

    /**
     * Runs a cross-validation on splits that were defined by the caller
     * instead of being created by partitioning.
     *
     * <p>The temporary mean result sets of this assessment are re-created (one
     * per classifier) before the evaluation starts.
     *
     * @param mp the performance measures to be computed
     * @param caaps the assessment parameters; its annotation becomes part of
     *            the annotation of the returned result
     * @param pU receives the progress of the cross-validation
     * @param splitData the predefined parts, {@code splitData[c][p]} being
     *            part {@code p} of class {@code c}
     * @return a {@link ListResult} containing the mean result sets of all
     *         classifiers, annotated with the kind of assessment, the
     *         annotation of {@code caaps} and the annotation of the used data
     * @throws IllegalArgumentException if the number of classes in
     *             {@code splitData} differs from the number of classes of the
     *             first classifier, or the classes have differing numbers of
     *             parts
     * @throws Exception if training or testing fails
     */
    public ListResult assessWithPredefinedSplits(NumericalPerformanceMeasureParameterSet mp, ClassifierAssessmentAssessParameterSet caaps, ProgressUpdater pU, DataSet[] ... splitData) throws Exception {
        int clazz = this.myAbstractClassifier[0].getNumberOfClasses();
        if (splitData.length != clazz) {
            throw new IllegalArgumentException("The number of classes in the data array and the classifier differs.");
        }
        this.myTempMeanResultSets = new MeanResultSet[this.myAbstractClassifier.length];
        for (int i = 0; i < this.myAbstractClassifier.length; ++i) {
            this.myTempMeanResultSets[i] = new MeanResultSet(this.myAbstractClassifier[i].getClassifierAnnotation());
        }
        this.evaluate(mp, caaps, pU, splitData);
        LinkedList<Result> annotation = new LinkedList<Result>();
        annotation.add(new CategoricalResult("kind of assessment", "a description or name of the assessment", this.getNameOfAssessment()));
        annotation.addAll(caaps.getAnnotation());
        // Collect the annotations of all classes as "[a0, a1, ...]".
        StringBuilder sb = new StringBuilder(1000);
        sb.append("[").append(DataSet.getAnnotation(splitData[0]));
        for (int i = 1; i < splitData.length; ++i) {
            sb.append(", ").append(DataSet.getAnnotation(splitData[i]));
        }
        sb.append("]");
        annotation.add(new CategoricalResult("samples", "annotation of used samples", "predefined splits: " + sb));
        return new ListResult(this.getNameOfAssessment(), "the results of a " + this.getNameOfAssessment() + " of predefined splits", new ResultSet(annotation), this.myTempMeanResultSets);
    }
}

