package projects.taleningner;

import cern.colt.matrix.impl.AbstractFormatter;
import de.jstacs.classifiers.assessment.RepeatedHoldOutAssessParameterSet;
import de.jstacs.classifiers.assessment.RepeatedHoldOutExperiment;
import de.jstacs.classifiers.differentiableSequenceScoreBased.OptimizableFunction;
import de.jstacs.classifiers.differentiableSequenceScoreBased.gendismix.GenDisMixClassifierParameterSet;
import de.jstacs.classifiers.differentiableSequenceScoreBased.msp.MSPClassifier;
import de.jstacs.classifiers.performanceMeasures.CorrelationCoefficient;
import de.jstacs.classifiers.performanceMeasures.NumericalPerformanceMeasureParameterSet;
import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.DataSet;
import de.jstacs.data.alphabets.ContinuousAlphabet;
import de.jstacs.data.alphabets.DiscreteAlphabet;
import de.jstacs.data.sequences.ArbitrarySequence;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.io.FileManager;
import de.jstacs.sequenceScores.differentiable.DifferentiableSequenceScore;
import de.jstacs.sequenceScores.differentiable.IndependentProductDiffSS;
import de.jstacs.sequenceScores.differentiable.UniformDiffSS;
import de.jstacs.sequenceScores.differentiable.logistic.LogisticConstraint;
import de.jstacs.sequenceScores.differentiable.logistic.LogisticDiffSS;
import de.jstacs.sequenceScores.differentiable.logistic.ProductConstraint;
import de.jstacs.sequenceScores.statisticalModels.differentiable.continuous.DirichletDiffSM;
import de.jstacs.sequenceScores.statisticalModels.differentiable.continuous.ExpGammaDiffSM;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.Pair;
import de.jstacs.utils.ToolBox;
import java.io.BufferedReader;
import java.io.FileReader;
import java.util.Arrays;
import java.util.LinkedList;
import org.biojava.bio.program.tagvalue.TagValueParser;
import projects.dimont.Interpolation;

/* loaded from: input_file:projects/taleningner/DimerTest.class */
/**
 * Command-line experiment that trains and evaluates an {@link MSPClassifier} on
 * pairs of monomer sequences.
 *
 * <p>All command-line arguments except the last are read as training files and
 * joined into one data set; the last argument is read as a separate file that the
 * trained classifier is scored on. Each input line is converted into a single
 * continuous feature vector (see {@link #getData(String)}).
 */
public class DimerTest {
    /** Location the trained classifier is written to as XML. */
    private static final String MODEL_OUTPUT_PATH = "/Users/dev/Desktop/TAL-Chips/Designer/model/model_test.xml";

    /** Largest value seen so far at index 4 of any {@code MonomerScoring.getValues} result. */
    static double maxmaxdist = 0.0d;

    /** Largest value seen so far at index 5 of any {@code MonomerScoring.getValues} result. */
    static double maxmaxstretch = 0.0d;

    /**
     * Reads one tab-separated file and converts every line into a feature vector.
     *
     * <p>Columns used per line: 0 and 1 are the two monomer sequences (symbols
     * separated by {@code "-"}), 2 is the numeric target value, and 3 and 6 form
     * a ratio {@code col3 / col6}. The ratio defaults to {@code 1.0} when column 3
     * is {@code "NA"} or when either column is not a parsable number. The feature
     * vector is the {@code MonomerScoring} values of the first sequence, followed by
     * those of the second, followed by the ratio.
     *
     * <p>As a side effect, {@link #maxmaxdist} and {@link #maxmaxstretch} are raised
     * to the largest values encountered.
     *
     * @param str path of the input file
     * @return the data set of feature vectors paired with three arrays: the target
     *         values, the weights computed by {@code Interpolation.getWeight}, and the
     *         weights computed by {@code Interpolation.getBgWeight} from those
     * @throws Exception if the file cannot be read or a line cannot be converted
     */
    private static Pair<DataSet, double[][]> getData(String str) throws Exception {
        System.out.println("path: " + str);
        AlphabetContainer rvdAlphabet = new AlphabetContainer(new DiscreteAlphabet(true, "NI", "NG", "NN", "NS", "N*", "ND", "NK", "NC", "NV", "NA", "NH", "HD", "HG", "HA", "H*", "HH", "HI", "HN", "S*", "SN", "SS", "IG", "YG", "NP", "NT", "IS"));
        LinkedList<Sequence> featureSequences = new LinkedList<>();
        DoubleList targets = new DoubleList();
        // try-with-resources: the reader was previously never closed.
        try (BufferedReader reader = new BufferedReader(new FileReader(str))) {
            MonomerScoring monomerScoring = new MonomerScoring(rvdAlphabet);
            AlphabetContainer featureAlphabet = new AlphabetContainer(new ContinuousAlphabet(true));
            String line;
            while ((line = reader.readLine()) != null) {
                String[] columns = line.split("\t");
                Sequence first = Sequence.create(rvdAlphabet, columns[0], "-");
                Sequence second = Sequence.create(rvdAlphabet, columns[1], "-");
                double target = Double.parseDouble(columns[2]);

                double ratio = 1.0d;
                if (!columns[3].trim().equals("NA")) {
                    try {
                        ratio = Double.parseDouble(columns[3]) / Double.parseDouble(columns[6]);
                    } catch (NumberFormatException ignored) {
                        // Unparsable ratio columns deliberately fall back to the neutral value.
                        ratio = 1.0d;
                    }
                }

                double[] firstValues = monomerScoring.getValues(first);
                double[] secondValues = monomerScoring.getValues(second);
                maxmaxdist = Math.max(maxmaxdist, Math.max(firstValues[4], secondValues[4]));
                maxmaxstretch = Math.max(maxmaxstretch, Math.max(firstValues[5], secondValues[5]));

                // Layout: [first monomer values | second monomer values | ratio]
                double[] features = new double[firstValues.length + secondValues.length + 1];
                System.arraycopy(firstValues, 0, features, 0, firstValues.length);
                System.arraycopy(secondValues, 0, features, firstValues.length, secondValues.length);
                features[features.length - 1] = ratio;

                featureSequences.add(new ArbitrarySequence(featureAlphabet, features));
                targets.add(target);
            }
        }
        double[] targetArray = targets.toArray();
        DataSet dataSet = new DataSet(TagValueParser.EMPTY_LINE_EOR, featureSequences);
        double[] weights = Interpolation.getWeight(dataSet, targetArray, 0.5d, Interpolation.PERCENTILE_LOGISTIC);
        return new Pair<>(dataSet, new double[][]{targetArray, weights, Interpolation.getBgWeight(weights)});
    }

    /**
     * Concatenates several data sets and their three companion arrays, preserving
     * the order of the given parts and of the elements within each part.
     *
     * @param pairArr the parts to join, each as produced by {@link #getData(String)}
     * @return one data set holding all sequences, with the three arrays concatenated
     *         in the same order
     * @throws Exception if the joined data set cannot be created
     */
    private static Pair<DataSet, double[][]> join(Pair<DataSet, double[][]>[] pairArr) throws Exception {
        LinkedList<Sequence> allSequences = new LinkedList<>();
        DoubleList targets = new DoubleList();
        DoubleList weights = new DoubleList();
        DoubleList bgWeights = new DoubleList();
        for (Pair<DataSet, double[][]> part : pairArr) {
            DataSet partData = part.getFirstElement();
            double[][] partArrays = part.getSecondElement();
            int count = partData.getNumberOfElements();
            for (int i = 0; i < count; i++) {
                allSequences.add(partData.getElementAt(i));
                targets.add(partArrays[0][i]);
                weights.add(partArrays[1][i]);
                bgWeights.add(partArrays[2][i]);
            }
        }
        return new Pair<>(new DataSet(TagValueParser.EMPTY_LINE_EOR, allSequences), new double[][]{targets.toArray(), weights.toArray(), bgWeights.toArray()});
    }

    /**
     * Loads and joins every file named in {@code strArr} except the last one.
     *
     * @param strArr file paths; the final entry is skipped here (main reads it separately)
     * @return the joined data of the first {@code strArr.length - 1} files
     * @throws IllegalArgumentException if {@code strArr} is empty
     * @throws Exception if a file cannot be loaded
     */
    private static Pair<DataSet, double[][]> getData(String[] strArr) throws Exception {
        if (strArr.length == 0) {
            // Previously surfaced as a NegativeArraySizeException from the allocation below.
            throw new IllegalArgumentException("At least one file path is required");
        }
        @SuppressWarnings("unchecked") // generic array creation; only Pair<DataSet, double[][]> is stored
        Pair<DataSet, double[][]>[] parts = new Pair[strArr.length - 1];
        for (int i = 0; i < parts.length; i++) {
            parts[i] = getData(strArr[i]);
        }
        return join(parts);
    }

    /**
     * Runs the experiment: loads the training files, assesses the classifier by
     * repeated hold-out, trains it on all training data, writes the model as XML,
     * and prints scores and Spearman correlations for the training data and for
     * the file given as last argument.
     *
     * @param strArr one or more training files followed by one test file
     * @throws IllegalArgumentException if fewer than two paths are given
     * @throws Exception if loading, training, assessment or writing fails
     */
    public static void main(String[] strArr) throws Exception {
        if (strArr.length < 2) {
            throw new IllegalArgumentException("Usage: DimerTest <training file>... <test file>");
        }
        Pair<DataSet, double[][]> trainingData = getData(strArr);
        System.out.println("maxmax: " + maxmaxdist + AbstractFormatter.DEFAULT_COLUMN_SEPARATOR + maxmaxstretch);
        DataSet trainingSet = trainingData.getFirstElement();
        double[][] trainingArrays = trainingData.getSecondElement();
        double[] fgWeights = trainingArrays[1];
        double[] bgWeights = trainingArrays[2];
        // Each feature vector is two equally built monomer parts plus one trailing ratio.
        int monomerLength = (trainingSet.getElementLength() - 1) / 2;
        AlphabetContainer alphabet = trainingSet.getAlphabetContainer();

        LinkedList<ProductConstraint> constraints = new LinkedList<>();
        constraints.add(new ProductConstraint(new int[0]));
        constraints.add(new ProductConstraint(0));
        constraints.add(new ProductConstraint(1));
        constraints.add(new ProductConstraint(2));
        constraints.add(new ProductConstraint(0, 2));
        constraints.add(new ProductConstraint(1, 2));
        constraints.add(new ProductConstraint(0, 1));
        constraints.add(new ProductConstraint(0, 1, 2));

        LinkedList<DifferentiableSequenceScore> scores = new LinkedList<>();
        scores.add(new ExpGammaDiffSM(alphabet, 1, 4.0d, new double[]{10.0d}, new double[]{9.0d}, true));
        scores.add(new DirichletDiffSM(alphabet, 2, new double[]{1.0d, 1.0d, 1.0d}, 4.0d, 1));
        scores.add(new ExpGammaDiffSM(alphabet, 1, 4.0d, new double[]{10.0d}, new double[]{9.0d}, true));
        scores.add(new ExpGammaDiffSM(alphabet, 1, 4.0d, new double[]{10.0d}, new double[]{9.0d}, true));
        scores.add(new ExpGammaDiffSM(alphabet, 1, 4.0d, new double[]{10.0d}, new double[]{9.0d}, true));
        scores.add(new DirichletDiffSM(alphabet, 3, new double[]{1.0d, 1.0d, 1.0d, 1.0d}, 4.0d, 1));
        scores.add(new ExpGammaDiffSM(alphabet, 1, 4.0d, new double[]{10.0d}, new double[]{9.0d}, true));
        // NOTE(review): this logistic score is constructed and then discarded; the uniform
        // score below is what is actually added. Kept in case the constructor validates.
        new LogisticDiffSS(alphabet, 3, (LogisticConstraint[]) constraints.toArray(new ProductConstraint[0]));
        scores.add(new UniformDiffSS(alphabet, 3));

        // Every ExpGamma component is wrapped as a difference of itself and a clone.
        // NOTE(review): mo116clone() looks like a decompiler-renamed clone() - confirm.
        DifferentiableSequenceScore[] components = new DifferentiableSequenceScore[scores.size()];
        int index = 0;
        for (DifferentiableSequenceScore score : scores) {
            if (score instanceof ExpGammaDiffSM) {
                components[index] = new DifferenceDiffSS(score, score.mo116clone());
            } else {
                components[index] = score;
            }
            index++;
        }

        IndependentProductDiffSS monomerModel = new IndependentProductDiffSS(true, components);
        ExpGammaDiffSM ratioModel = new ExpGammaDiffSM(alphabet, 1, 4.0d, new double[]{10.0d}, new double[]{9.0d}, true);
        TALENDiffSM talenModel = new TALENDiffSM(monomerModel, new DifferenceDiffSS(ratioModel, ratioModel.mo116clone()), monomerLength, true);
        MSPClassifier classifier = new MSPClassifier(new GenDisMixClassifierParameterSet(talenModel.getAlphabetContainer(), talenModel.getLength(), (byte) 20, 1.0E-12d, 1.0E-12d, 1.0E-4d, false, OptimizableFunction.KindOfParameter.PLUGIN, true, 1), talenModel, new UniformDiffSS(alphabet, trainingSet.getElementLength()));

        // The same data set serves as both classes; the two weight arrays distinguish them.
        DataSet[] classData = new DataSet[]{trainingSet, trainingSet};
        double[][] classWeights = new double[][]{fgWeights, bgWeights};
        System.out.println(new RepeatedHoldOutExperiment(classifier).assess(new NumericalPerformanceMeasureParameterSet(new CorrelationCoefficient(CorrelationCoefficient.Method.SPEARMAN, true)), new RepeatedHoldOutAssessParameterSet(DataSet.PartitionMethod.PARTITION_BY_NUMBER_OF_ELEMENTS, trainingSet.getElementLength(), true, 2, new double[]{0.3d, 0.3d}), null, classData, classWeights));
        classifier.train(classData, classWeights);
        FileManager.writeFile(MODEL_OUTPUT_PATH, classifier.toXML());
        System.out.println(classifier);
        System.out.println(ToolBox.spearmanCorrelation(fgWeights, classifier.getScores(trainingSet)));

        // Score the file given as last argument and compare against its target values.
        Pair<DataSet, double[][]> testData = getData(strArr[strArr.length - 1]);
        double[] testScores = classifier.getScores(testData.getFirstElement());
        System.out.println(Arrays.toString(testScores));
        double[] testTargets = testData.getSecondElement()[0];
        System.out.println(ToolBox.spearmanCorrelation(testTargets, testScores));

        // Split the predicted scores by whether the measured target exceeds 10.
        DoubleList scoresAboveThreshold = new DoubleList();
        DoubleList scoresAtOrBelowThreshold = new DoubleList();
        for (int i = 0; i < testTargets.length; i++) {
            if (testTargets[i] > 10.0d) {
                scoresAboveThreshold.add(testScores[i]);
            } else {
                scoresAtOrBelowThreshold.add(testScores[i]);
            }
        }
        System.out.println(scoresAboveThreshold);
        System.out.println(scoresAtOrBelowThreshold);
    }
}
