package projects.proteinmotifs;

import de.jstacs.classifiers.assessment.KFoldCrossValidation;
import de.jstacs.classifiers.assessment.KFoldCrossValidationAssessParameterSet;
import de.jstacs.classifiers.differentiableSequenceScoreBased.OptimizableFunction;
import de.jstacs.classifiers.differentiableSequenceScoreBased.gendismix.GenDisMixClassifier;
import de.jstacs.classifiers.differentiableSequenceScoreBased.gendismix.GenDisMixClassifierParameterSet;
import de.jstacs.classifiers.differentiableSequenceScoreBased.gendismix.LearningPrinciple;
import de.jstacs.classifiers.differentiableSequenceScoreBased.logPrior.CompositeLogPrior;
import de.jstacs.classifiers.performanceMeasures.NumericalPerformanceMeasureParameterSet;
import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.DataSet;
import de.jstacs.data.alphabets.ContinuousAlphabet;
import de.jstacs.io.ArrayHandler;
import de.jstacs.io.SparseStringExtractor;
import de.jstacs.sequenceScores.statisticalModels.differentiable.DifferentiableStatisticalModel;
import de.jstacs.sequenceScores.statisticalModels.differentiable.IndependentProductDiffSM;
import de.jstacs.sequenceScores.statisticalModels.differentiable.continuous.SingleGaussianDiffSM;
import java.io.IOException;
import java.io.PrintWriter;

/* loaded from: input_file:projects/proteinmotifs/Classifier.class */
public class Classifier {
    public static void main(String[] strArr) throws Exception {
        DataSet dataSet = new DataSet(new AlphabetContainer(new ContinuousAlphabet()), new SparseStringExtractor(strArr[0], '#'), "\t");
        DataSet dataSet2 = new DataSet(new AlphabetContainer(new ContinuousAlphabet()), new SparseStringExtractor(strArr[1], '#'), "\t");
        DifferentiableStatisticalModel[] differentiableStatisticalModelArr = new DifferentiableStatisticalModel[dataSet.getElementLength()];
        differentiableStatisticalModelArr[0] = new SingleGaussianDiffSM(dataSet.getAlphabetContainer(), 64.0d, 90.0d, 4.938271604938272E-4d, 4.938271604938272E-4d, false);
        int i = 0 + 1;
        differentiableStatisticalModelArr[i] = new SingleGaussianDiffSM(dataSet.getAlphabetContainer(), 64.0d, -3.0d, 1000.0d, 500.0d, false);
        int i2 = i + 1;
        differentiableStatisticalModelArr[i2] = new SingleGaussianDiffSM(dataSet.getAlphabetContainer(), 64.0d, -3.0d, 1000.0d, 500.0d, false);
        int i3 = i2 + 1;
        for (int i4 = 0; i4 < 20; i4++) {
            differentiableStatisticalModelArr[i3 + i4] = new SingleGaussianDiffSM(dataSet.getAlphabetContainer(), 64.0d, 0.05d, 1000.0d, 500.0d, false);
        }
        int i5 = i3 + 20;
        for (int i6 = 0; i6 < 7; i6++) {
            differentiableStatisticalModelArr[i5 + i6] = new SingleGaussianDiffSM(dataSet.getAlphabetContainer(), 64.0d, 0.14285714285714285d, 1000.0d, 500.0d, false);
        }
        GenDisMixClassifier genDisMixClassifier = new GenDisMixClassifier(new GenDisMixClassifierParameterSet(dataSet.getAlphabetContainer(), dataSet.getElementLength(), (byte) 20, 1.0E-6d, 1.0E-6d, 1.0E-4d, false, OptimizableFunction.KindOfParameter.PLUGIN, true, 1), new CompositeLogPrior(), LearningPrinciple.MSP, new IndependentProductDiffSM(64.0d, true, differentiableStatisticalModelArr), new IndependentProductDiffSM(64.0d, true, (DifferentiableStatisticalModel[]) ArrayHandler.clone(differentiableStatisticalModelArr)));
        System.out.println(new KFoldCrossValidation(genDisMixClassifier).assess((NumericalPerformanceMeasureParameterSet) NumericalPerformanceMeasureParameterSet.createFilledParameters(true, 0.99d, 0.9d, 0.9d, 1.0d), new KFoldCrossValidationAssessParameterSet(DataSet.PartitionMethod.PARTITION_BY_NUMBER_OF_ELEMENTS, dataSet.getElementLength(), true, 5), dataSet, dataSet2));
        genDisMixClassifier.train(dataSet, dataSet2);
        System.out.println(genDisMixClassifier.evaluate((NumericalPerformanceMeasureParameterSet) NumericalPerformanceMeasureParameterSet.createFilledParameters(true, 0.999d, 0.9d, 0.9d, 1.0d), true, dataSet, dataSet2));
        double[] scores = genDisMixClassifier.getScores(dataSet);
        double[] scores2 = genDisMixClassifier.getScores(dataSet2);
        PrintWriter printWriter = new PrintWriter(String.valueOf(strArr[0]) + ".cl");
        for (double d : scores) {
            printWriter.println(d);
        }
        printWriter.close();
        PrintWriter printWriter2 = new PrintWriter(String.valueOf(strArr[1]) + ".cl");
        for (double d2 : scores2) {
            printWriter2.println(d2);
        }
        printWriter2.close();
    }
}
