package projects.plantdream;

import de.jstacs.classifiers.differentiableSequenceScoreBased.OptimizableFunction;
import de.jstacs.classifiers.differentiableSequenceScoreBased.gendismix.GenDisMixClassifier;
import de.jstacs.classifiers.differentiableSequenceScoreBased.gendismix.GenDisMixClassifierParameterSet;
import de.jstacs.classifiers.differentiableSequenceScoreBased.gendismix.LearningPrinciple;
import de.jstacs.classifiers.differentiableSequenceScoreBased.logPrior.LogPrior;
import de.jstacs.classifiers.performanceMeasures.AucPR;
import de.jstacs.classifiers.performanceMeasures.AucROC;
import de.jstacs.classifiers.performanceMeasures.NumericalPerformanceMeasureParameterSet;
import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.DataSet;
import de.jstacs.data.EmptyDataSetException;
import de.jstacs.data.WrongAlphabetException;
import de.jstacs.data.alphabets.ContinuousAlphabet;
import de.jstacs.io.SparseStringExtractor;
import de.jstacs.sequenceScores.statisticalModels.differentiable.continuous.GaussianNetwork;
import de.jstacs.utils.Pair;
import de.jstacs.utils.ToolBox;
import java.util.LinkedList;

/* loaded from: input_file:projects/plantdream/TrainUnsupervisedAlternate.class */
/**
 * Command-line driver that trains a two-class {@link GenDisMixClassifier} on continuous data
 * without using which input file an element came from.
 *
 * <p>How training works:
 * <ul>
 *   <li>The first two input files are pooled.</li>
 *   <li>Each pooled element gets a weight, initially taken from a single column of the data.</li>
 *   <li>Elements whose weight is at or above a percentile threshold form the first training class;
 *       all other elements form the second.</li>
 * </ul>
 *
 * <p>Optionally (see the {@code [iterations]} argument), the weights are refined by alternating
 * between two column subsets ("views") of the pooled data. In each round:
 * <ul>
 *   <li>a classifier is trained on a split of one view;</li>
 *   <li>that view is re-scored;</li>
 *   <li>the z-scored scores are combined with the previous z-scored weights;</li>
 *   <li>the other view is re-split for the next round.</li>
 * </ul>
 *
 * <p>Arguments: {@code <data-a> <data-b> <test-a> <test-b> [iterations]}. All files are read
 * tab-separated with {@link SparseStringExtractor} and {@code '>'} as its character argument
 * (presumably a header/comment marker — confirm against the jstacs docs). The final classifier's
 * AUC-ROC / AUC-PR are printed for ({@code data-a}, {@code data-b}) and ({@code test-a},
 * {@code test-b}).
 */
public class TrainUnsupervisedAlternate {

    /** Message reported when the command-line arguments are invalid. */
    private static final String USAGE =
            "usage: TrainUnsupervisedAlternate <data-a> <data-b> <test-a> <test-b> [iterations]";

    /** Column of view A whose values seed the initial weights. */
    private static final int SEED_COLUMN = 2;

    /** Percentile threshold used for the initial split of view B. */
    private static final double INITIAL_PERCENTILE = 0.9d;

    /** Percentile threshold used for every split after the initial one. */
    private static final double PERCENTILE = 0.95d;

    /** Weight of the previous (z-scored) weights relative to the new (z-scored) classifier scores. */
    private static final double PREVIOUS_WEIGHT = 2.0d;

    /** Number of refinement rounds when no {@code [iterations]} argument is given; 0 disables refinement. */
    private static final int DEFAULT_ITERATIONS = 0;

    /**
     * Entry point.
     *
     * @param strArr {@code <data-a> <data-b> <test-a> <test-b> [iterations]}; any further
     *     arguments are ignored
     * @throws IllegalArgumentException if fewer than four arguments are given or
     *     {@code iterations} is not a non-negative integer
     * @throws Exception propagated from jstacs data loading, training or evaluation
     */
    public static void main(String[] strArr) throws Exception {
        // Validate up front: previously a missing test file argument only failed after the
        // full model had already been trained.
        if (strArr.length < 4) {
            throw new IllegalArgumentException(USAGE);
        }
        int iterations = strArr.length > 4 ? parseIterations(strArr[4]) : DEFAULT_ITERATIONS;

        AlphabetContainer alphabetContainer = new AlphabetContainer(new ContinuousAlphabet());
        DataSet dataA = readDataSet(alphabetContainer, strArr[0]);
        DataSet dataB = readDataSet(alphabetContainer, strArr[1]);
        DataSet union = DataSet.union(dataA, dataB);

        // Two views of every pooled element, built with getCompositeDataSet(starts, lengths):
        // view A from segments starting at 9, 20, 31 of length 2 each,
        // view B from segments starting at 0, 11, 22 of length 9 each.
        DataSet viewA = union.getCompositeDataSet(new int[]{9, 20, 31}, new int[]{2, 2, 2});
        DataSet viewB = union.getCompositeDataSet(new int[]{0, 11, 22}, new int[]{9, 9, 9});

        NumericalPerformanceMeasureParameterSet measures =
                new NumericalPerformanceMeasureParameterSet(new AucROC(), new AucPR());

        // Seed weights come from one column of view A; they also give the initial split of view B.
        Pair<DataSet[], double[]> seed = split(viewA, viewB, SEED_COLUMN);
        double[] weights = refineWeights(alphabetContainer, measures, viewA, viewB,
                seed.getFirstElement(), seed.getSecondElement(), iterations);

        // Final model: trained on all columns of the pooled data, split by the (refined) weights.
        DataSet[] finalTraining = split(union, weights, PERCENTILE);
        GenDisMixClassifier classifier =
                createClassifier(alphabetContainer, finalTraining[0].getElementLength());
        classifier.train(finalTraining);
        System.out.println(classifier.evaluate(measures, true, dataA, dataB));

        DataSet testA = readDataSet(alphabetContainer, strArr[2]);
        DataSet testB = readDataSet(alphabetContainer, strArr[3]);
        System.out.println(classifier.evaluate(measures, true, testA, testB));
    }

    /**
     * Parses the optional iteration-count argument.
     *
     * @param value the raw argument
     * @return the parsed count, {@code >= 0}
     * @throws IllegalArgumentException if {@code value} is not a non-negative integer
     */
    private static int parseIterations(String value) {
        int iterations;
        try {
            iterations = Integer.parseInt(value.trim());
        } catch (NumberFormatException e) {
            throw new IllegalArgumentException(
                    "iterations must be a non-negative integer, got '" + value + "'; " + USAGE, e);
        }
        if (iterations < 0) {
            throw new IllegalArgumentException(
                    "iterations must be a non-negative integer, got " + iterations + "; " + USAGE);
        }
        return iterations;
    }

    /**
     * Reads one tab-separated continuous data file.
     *
     * @param alphabetContainer the (continuous) alphabet of the data
     * @param path file to read
     * @return the loaded data set
     * @throws Exception propagated from {@link SparseStringExtractor} / {@link DataSet}
     */
    private static DataSet readDataSet(AlphabetContainer alphabetContainer, String path) throws Exception {
        return new DataSet(alphabetContainer, new SparseStringExtractor(path, '>'), "\t");
    }

    /**
     * Creates an untrained two-class classifier for sequences of the given length.
     *
     * <p>Both classes use a {@link GaussianNetwork} built from {@code new int[length][0]}, which is
     * presumably an empty parent list for every position (no dependencies) — confirm against the
     * GaussianNetwork docs. The same model instance is passed for both classes; this presumably
     * relies on GenDisMixClassifier copying its models — TODO confirm.
     *
     * @param alphabetContainer alphabet of the data
     * @param length element length of the training data
     * @return a classifier trained with {@link LearningPrinciple#MCL} and no prior
     * @throws Exception propagated from the jstacs constructors
     */
    private static GenDisMixClassifier createClassifier(AlphabetContainer alphabetContainer, int length)
            throws Exception {
        GaussianNetwork model = new GaussianNetwork(new int[length][0]);
        // Arguments are passed positionally exactly as before; the trailing 16 is presumably the
        // number of threads — confirm against GenDisMixClassifierParameterSet.
        GenDisMixClassifierParameterSet params = new GenDisMixClassifierParameterSet(
                alphabetContainer, length, (byte) 20, 1.0E-6d, 1.0E-6d, 1.0E-4d, false,
                OptimizableFunction.KindOfParameter.ZEROS, true, 16);
        return new GenDisMixClassifier(params, (LogPrior) null, LearningPrinciple.MCL, model, model);
    }

    /**
     * Runs {@code iterations} rounds of alternating refinement of the weights.
     *
     * <p>Each round does the following:
     * <ol>
     *   <li>trains a classifier on the current training split;</li>
     *   <li>prints its performance on that same split (training-set performance);</li>
     *   <li>re-scores the view the split came from;</li>
     *   <li>combines the new scores with the previous weights via {@link #updateVals};</li>
     *   <li>splits the other view with the new weights for the next round.</li>
     * </ol>
     * With {@code iterations == 0} the initial weights are returned unchanged.
     *
     * @param alphabetContainer alphabet of the data
     * @param measures performance measures to print each round
     * @param viewA first view; split in odd rounds (1st, 3rd, ...)
     * @param viewB second view; {@code initialTraining} must be a split of it
     * @param initialTraining the initial two-class split of {@code viewB}
     * @param initialWeights the initial weights, one per element
     * @param iterations number of rounds, {@code >= 0}
     * @return the refined weights
     * @throws Exception propagated from training, evaluation or scoring
     */
    private static double[] refineWeights(AlphabetContainer alphabetContainer,
            NumericalPerformanceMeasureParameterSet measures, DataSet viewA, DataSet viewB,
            DataSet[] initialTraining, double[] initialWeights, int iterations) throws Exception {
        double[] weights = initialWeights;
        DataSet[] training = initialTraining;
        DataSet trainView = viewB; // view the current training split was drawn from
        DataSet otherView = viewA; // view that will be split next
        for (int round = 0; round < iterations; round++) {
            GenDisMixClassifier classifier =
                    createClassifier(alphabetContainer, training[0].getElementLength());
            classifier.train(training);
            System.out.println(classifier.evaluate(measures, true, training[0], training[1]));
            weights = updateVals(weights, classifier, trainView);
            training = split(otherView, weights, PERCENTILE);
            // Swap the roles of the two views for the next round.
            DataSet previousTrainView = trainView;
            trainView = otherView;
            otherView = previousTrainView;
        }
        return weights;
    }

    /**
     * Combines previous weights with new classifier scores.
     *
     * <p>The result for element {@code i} is
     * {@code z(scores)[i] + PREVIOUS_WEIGHT * z(previous)[i]}, where {@code z} is
     * {@link ToolBox#zscore} and {@code scores} is {@code classifier.getScores(dataSet)}.
     *
     * @param dArr previous weights, one per element of {@code dataSet}
     * @param genDisMixClassifier trained classifier used for scoring
     * @param dataSet elements to score
     * @return the combined weights (a new array)
     * @throws IllegalArgumentException if the number of scores differs from {@code dArr.length}
     * @throws Exception propagated from scoring
     */
    private static double[] updateVals(double[] dArr, GenDisMixClassifier genDisMixClassifier, DataSet dataSet)
            throws Exception {
        double[] previous = ToolBox.zscore(dArr);
        double[] scores = genDisMixClassifier.getScores(dataSet);
        if (scores.length != previous.length) {
            throw new IllegalArgumentException("got " + scores.length + " scores for "
                    + previous.length + " previous weights");
        }
        double[] combined = ToolBox.zscore(scores);
        for (int i = 0; i < combined.length; i++) {
            combined[i] += PREVIOUS_WEIGHT * previous[i];
        }
        return combined;
    }

    /**
     * Takes column {@code i} of every element of {@code dataSet} as weights and uses them to split
     * {@code dataSet2} at the {@link #INITIAL_PERCENTILE} threshold.
     *
     * @param dataSet source of the weights
     * @param dataSet2 data set to split; must have as many elements as {@code dataSet}
     * @param i column index passed to {@code continuousVal}
     * @return the split of {@code dataSet2} (see
     *     {@link #split(DataSet, double[], double)}) paired with the extracted weights
     * @throws EmptyDataSetException see {@link #split(DataSet, double[], double)}
     * @throws WrongAlphabetException see {@link #split(DataSet, double[], double)}
     */
    private static Pair<DataSet[], double[]> split(DataSet dataSet, DataSet dataSet2, int i)
            throws EmptyDataSetException, WrongAlphabetException {
        int count = dataSet.getNumberOfElements();
        double[] weights = new double[count];
        for (int e = 0; e < count; e++) {
            weights[e] = dataSet.getElementAt(e).continuousVal(i);
        }
        return new Pair<>(split(dataSet2, weights, INITIAL_PERCENTILE), weights);
    }

    /**
     * Splits {@code dataSet} into two classes by thresholding weights at a percentile.
     *
     * <p>The threshold is {@code ToolBox.percentile(dArr, d)}. Elements with weight
     * {@code >= threshold} go to the first data set; all others go to the second. Element order
     * is preserved within each part.
     *
     * @param dataSet data to split
     * @param dArr one weight per element of {@code dataSet}
     * @param d percentile passed to {@link ToolBox#percentile}
     * @return {@code {atOrAboveThreshold, belowThreshold}}
     * @throws IllegalArgumentException if {@code dArr.length} differs from the number of elements
     * @throws EmptyDataSetException presumably thrown by the {@link DataSet} constructor when one
     *     side of the split is empty — confirm
     * @throws WrongAlphabetException propagated from the {@link DataSet} constructor
     */
    // Raw collections: the element type is whatever DataSet.getElementAt returns; kept raw so this
    // class does not depend on that type's exact generic signature.
    @SuppressWarnings({"rawtypes", "unchecked"})
    private static DataSet[] split(DataSet dataSet, double[] dArr, double d)
            throws EmptyDataSetException, WrongAlphabetException {
        int count = dataSet.getNumberOfElements();
        if (dArr.length != count) {
            throw new IllegalArgumentException("got " + dArr.length + " weights for "
                    + count + " elements");
        }
        double threshold = ToolBox.percentile(dArr, d);
        LinkedList atOrAbove = new LinkedList();
        LinkedList below = new LinkedList();
        for (int i = 0; i < count; i++) {
            if (dArr[i] < threshold) {
                below.add(dataSet.getElementAt(i));
            } else {
                atOrAbove.add(dataSet.getElementAt(i));
            }
        }
        return new DataSet[]{new DataSet("", atOrAbove), new DataSet("", below)};
    }
}
