/*
 * Decompiled with CFR 0.152.
 */
package projects.taleningner;

import de.jstacs.classifiers.assessment.RepeatedHoldOutAssessParameterSet;
import de.jstacs.classifiers.assessment.RepeatedHoldOutExperiment;
import de.jstacs.classifiers.differentiableSequenceScoreBased.OptimizableFunction;
import de.jstacs.classifiers.differentiableSequenceScoreBased.gendismix.GenDisMixClassifierParameterSet;
import de.jstacs.classifiers.differentiableSequenceScoreBased.msp.MSPClassifier;
import de.jstacs.classifiers.performanceMeasures.CorrelationCoefficient;
import de.jstacs.classifiers.performanceMeasures.NumericalPerformanceMeasureParameterSet;
import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.DataSet;
import de.jstacs.data.alphabets.Alphabet;
import de.jstacs.data.alphabets.ContinuousAlphabet;
import de.jstacs.data.alphabets.DiscreteAlphabet;
import de.jstacs.data.sequences.ArbitrarySequence;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.io.FileManager;
import de.jstacs.sequenceScores.differentiable.AbstractDifferentiableSequenceScore;
import de.jstacs.sequenceScores.differentiable.DifferentiableSequenceScore;
import de.jstacs.sequenceScores.differentiable.IndependentProductDiffSS;
import de.jstacs.sequenceScores.differentiable.UniformDiffSS;
import de.jstacs.sequenceScores.differentiable.logistic.LogisticDiffSS;
import de.jstacs.sequenceScores.differentiable.logistic.ProductConstraint;
import de.jstacs.sequenceScores.statisticalModels.differentiable.continuous.DirichletDiffSM;
import de.jstacs.sequenceScores.statisticalModels.differentiable.continuous.ExpGammaDiffSM;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.Pair;
import de.jstacs.utils.ToolBox;
import java.io.BufferedReader;
import java.io.FileReader;
import java.util.Arrays;
import java.util.LinkedList;
import projects.dimont.Interpolation;
import projects.taleningner.DifferenceDiffSS;
import projects.taleningner.MonomerScoring;
import projects.taleningner.TALENDiffSM;

/**
 * Command-line experiment that trains and evaluates an {@link MSPClassifier} on
 * pairs of RVD sequences read from tab-separated files.
 *
 * <p>Every command-line argument except the last names a training file; the last
 * argument names a held-out test file. See {@link #main(String[])}.
 */
public class DimerTest {
    /** Largest value observed so far at index 4 of any vector returned by {@code MonomerScoring.getValues}. */
    static double maxmaxdist = 0.0;
    /** Largest value observed so far at index 5 of any vector returned by {@code MonomerScoring.getValues}. */
    static double maxmaxstretch = 0.0;

    /** Location the trained classifier is written to as XML. */
    private static final String MODEL_OUTPUT_PATH = "/Users/dev/Desktop/TAL-Chips/Designer/model/model_test.xml";

    /**
     * Reads one tab-separated file and converts every line into a continuous feature sequence.
     *
     * <p>Per line, column 0 and column 1 are parsed as RVD sequences (passing {@code "-"} to
     * {@code Sequence.create}, presumably as the token delimiter), column 2 is the numeric
     * intensity, and columns 3 and 6 feed {@link #parseRatioFeature(String[])}. The resulting
     * feature vector is the left monomer values, followed by the right monomer values,
     * followed by the ratio feature.
     *
     * <p>As a side effect, {@link #maxmaxdist} and {@link #maxmaxstretch} are raised whenever
     * a larger value is seen at index 4 or 5 of a monomer value vector.
     *
     * @param path path of the file to read
     * @return the data set together with three parallel arrays: raw intensities, the weights
     *         computed by {@code Interpolation.getWeight}, and the background weights computed
     *         by {@code Interpolation.getBgWeight}
     * @throws Exception if the file cannot be read or a line cannot be parsed
     */
    private static Pair<DataSet, double[][]> getData(String path) throws Exception {
        System.out.println("path: " + path);
        String[] alph = new String[]{"NI", "NG", "NN", "NS", "N*", "ND", "NK", "NC", "NV", "NA", "NH", "HD", "HG", "HA", "H*", "HH", "HI", "HN", "S*", "SN", "SS", "IG", "YG", "NP", "NT", "IS"};
        AlphabetContainer alphabetsRVD = new AlphabetContainer((Alphabet)new DiscreteAlphabet(true, alph));
        LinkedList<Sequence> seq = new LinkedList<Sequence>();
        DoubleList ws = new DoubleList();
        BufferedReader read = new BufferedReader(new FileReader(path));
        try {
            MonomerScoring sc = new MonomerScoring(alphabetsRVD);
            AlphabetContainer cont = new AlphabetContainer((Alphabet)new ContinuousAlphabet(true));
            String str;
            while ((str = read.readLine()) != null) {
                String[] parts = str.split("\t");
                Sequence left = Sequence.create(alphabetsRVD, parts[0], "-");
                Sequence right = Sequence.create(alphabetsRVD, parts[1], "-");
                double w = Double.parseDouble(parts[2]);
                double d = DimerTest.parseRatioFeature(parts);

                double[] lefts = sc.getValues(left);
                double[] rights = sc.getValues(right);
                // Track the running maxima; the explicit comparison leaves them untouched for NaN values.
                if (Math.max(lefts[4], rights[4]) > maxmaxdist) {
                    maxmaxdist = Math.max(lefts[4], rights[4]);
                }
                if (Math.max(lefts[5], rights[5]) > maxmaxstretch) {
                    maxmaxstretch = Math.max(lefts[5], rights[5]);
                }

                // Layout: [left monomer values | right monomer values | ratio feature].
                double[] all = new double[lefts.length + rights.length + 1];
                System.arraycopy(lefts, 0, all, 0, lefts.length);
                System.arraycopy(rights, 0, all, lefts.length, rights.length);
                all[all.length - 1] = d;
                seq.add(new ArbitrarySequence(cont, all));
                ws.add(w);
            }
        } finally {
            // Release the file handle even when a line fails to parse.
            read.close();
        }
        double[] intensa = ws.toArray();
        DataSet contData = new DataSet("", seq);
        double[] weights = Interpolation.getWeight(contData, intensa, 0.5, Interpolation.PERCENTILE_LOGISTIC);
        double[] bgWeights = Interpolation.getBgWeight(weights);
        return new Pair<DataSet, double[][]>(contData, new double[][]{intensa, weights, bgWeights});
    }

    /**
     * Computes the last entry of a line's feature vector from columns 3 and 6.
     *
     * @param parts the tab-separated columns of one input line
     * @return {@code 1.0} if column 3 is {@code "NA"} (after trimming) or column 6 is not a
     *         number; otherwise column 3 divided by column 6
     * @throws NumberFormatException if column 3 is neither {@code "NA"} nor a number
     */
    private static double parseRatioFeature(String[] parts) {
        if (parts[3].trim().equals("NA")) {
            return 1.0;
        }
        double value = Double.parseDouble(parts[3]);
        try {
            return value / Double.parseDouble(parts[6]);
        } catch (NumberFormatException ignored) {
            // A non-numeric divisor falls back to the same neutral value used for "NA".
            return 1.0;
        }
    }

    /**
     * Concatenates several data sets and their parallel weight arrays, preserving order.
     *
     * @param pairs data sets, each paired with three arrays (intensities, weights, background
     *              weights) holding one entry per sequence
     * @return a single data set containing all sequences and the three concatenated arrays
     * @throws Exception if the combined data set cannot be created
     */
    private static Pair<DataSet, double[][]> join(Pair<DataSet, double[][]>[] pairs) throws Exception {
        LinkedList<Sequence> seqs = new LinkedList<Sequence>();
        DoubleList intens = new DoubleList();
        DoubleList w = new DoubleList();
        DoubleList bgw = new DoubleList();
        for (Pair<DataSet, double[][]> pair : pairs) {
            DataSet data = pair.getFirstElement();
            double[][] weights = pair.getSecondElement();
            int n = data.getNumberOfElements();
            for (int j = 0; j < n; ++j) {
                seqs.add(data.getElementAt(j));
                intens.add(weights[0][j]);
                w.add(weights[1][j]);
                bgw.add(weights[2][j]);
            }
        }
        return new Pair<DataSet, double[][]>(new DataSet("", seqs), new double[][]{intens.toArray(), w.toArray(), bgw.toArray()});
    }

    /**
     * Loads and joins the training files, i.e. every argument except the last one.
     *
     * @param args command-line arguments; the final entry is skipped here
     * @return the joined training data
     * @throws IllegalArgumentException if {@code args} is empty
     * @throws Exception if any file cannot be read or parsed
     */
    private static Pair<DataSet, double[][]> getData(String[] args) throws Exception {
        if (args.length == 0) {
            throw new IllegalArgumentException("At least one file argument is required");
        }
        // Generic array creation is not allowed in Java; the array only ever holds Pair<DataSet, double[][]>.
        @SuppressWarnings("unchecked")
        Pair<DataSet, double[][]>[] pairs = new Pair[args.length - 1];
        for (int i = 0; i < pairs.length; ++i) {
            pairs[i] = DimerTest.getData(args[i]);
        }
        return DimerTest.join(pairs);
    }

    /**
     * Trains the classifier on the training files, reports hold-out and training performance,
     * writes the model to {@link #MODEL_OUTPUT_PATH}, and scores the test file.
     *
     * @param args one or more training file paths followed by exactly one test file path
     * @throws IllegalArgumentException if fewer than two paths are given
     * @throws Exception if reading, training, or evaluation fails
     */
    public static void main(String[] args) throws Exception {
        if (args.length < 2) {
            throw new IllegalArgumentException("Usage: DimerTest <training file>... <test file>");
        }
        Pair<DataSet, double[][]> pair = DimerTest.getData(args);
        System.out.println("maxmax: " + maxmaxdist + " " + maxmaxstretch);
        DataSet contData = pair.getFirstElement();
        double[] weights = pair.getSecondElement()[1];
        double[] bgWeights = pair.getSecondElement()[2];
        // Each sequence is [left values | right values | one extra feature].
        int monomerScoreLength = (contData.getElementLength() - 1) / 2;

        LinkedList<ProductConstraint> constraints = new LinkedList<ProductConstraint>();
        constraints.add(new ProductConstraint(new int[0]));
        constraints.add(new ProductConstraint(0));
        constraints.add(new ProductConstraint(1));
        constraints.add(new ProductConstraint(2));
        constraints.add(new ProductConstraint(0, 2));
        constraints.add(new ProductConstraint(1, 2));
        constraints.add(new ProductConstraint(0, 1));
        constraints.add(new ProductConstraint(0, 1, 2));

        LinkedList<AbstractDifferentiableSequenceScore> sslist = new LinkedList<AbstractDifferentiableSequenceScore>();
        sslist.add(new ExpGammaDiffSM(contData.getAlphabetContainer(), 1, 4.0, new double[]{10.0}, new double[]{9.0}, true));
        sslist.add(new DirichletDiffSM(contData.getAlphabetContainer(), 2, new double[]{1.0, 1.0, 1.0}, 4.0, 1));
        sslist.add(new ExpGammaDiffSM(contData.getAlphabetContainer(), 1, 4.0, new double[]{10.0}, new double[]{9.0}, true));
        sslist.add(new ExpGammaDiffSM(contData.getAlphabetContainer(), 1, 4.0, new double[]{10.0}, new double[]{9.0}, true));
        sslist.add(new ExpGammaDiffSM(contData.getAlphabetContainer(), 1, 4.0, new double[]{10.0}, new double[]{9.0}, true));
        sslist.add(new DirichletDiffSM(contData.getAlphabetContainer(), 3, new double[]{1.0, 1.0, 1.0, 1.0}, 4.0, 1));
        sslist.add(new ExpGammaDiffSM(contData.getAlphabetContainer(), 1, 4.0, new double[]{10.0}, new double[]{9.0}, true));
        // NOTE(review): this logistic score is constructed but never added to the model; the
        // uniform score below is used in its place. Kept because the constructor still runs.
        LogisticDiffSS ss = new LogisticDiffSS(contData.getAlphabetContainer(), 3, constraints.toArray(new ProductConstraint[0]));
        sslist.add(new UniformDiffSS(contData.getAlphabetContainer(), 3));

        // Wrap every ExpGamma component as a difference between itself and a clone of itself.
        DifferentiableSequenceScore[] scores = new DifferentiableSequenceScore[sslist.size()];
        int idx = 0;
        for (AbstractDifferentiableSequenceScore component : sslist) {
            DifferentiableSequenceScore el = (DifferentiableSequenceScore)component;
            scores[idx++] = el instanceof ExpGammaDiffSM ? new DifferenceDiffSS(el, el.clone()) : el;
        }
        IndependentProductDiffSS fg = new IndependentProductDiffSS(true, scores);

        AbstractDifferentiableSequenceScore spacer = new ExpGammaDiffSM(contData.getAlphabetContainer(), 1, 4.0, new double[]{10.0}, new double[]{9.0}, true);
        DifferentiableSequenceScore spacer2 = spacer.clone();
        spacer = new DifferenceDiffSS(spacer, spacer2);
        TALENDiffSM score = new TALENDiffSM(fg, spacer, monomerScoreLength, true);
        UniformDiffSS bg = new UniformDiffSS(contData.getAlphabetContainer(), contData.getElementLength());

        GenDisMixClassifierParameterSet params = new GenDisMixClassifierParameterSet(score.getAlphabetContainer(), score.getLength(), 20, 1.0E-12, 1.0E-12, 1.0E-4, false, OptimizableFunction.KindOfParameter.PLUGIN, true, 1);
        MSPClassifier cl = new MSPClassifier(params, score, bg);

        // Repeated hold-out assessment, reported as a Spearman correlation.
        NumericalPerformanceMeasureParameterSet mp = new NumericalPerformanceMeasureParameterSet(new CorrelationCoefficient(CorrelationCoefficient.Method.SPEARMAN, true));
        RepeatedHoldOutAssessParameterSet assessPS = new RepeatedHoldOutAssessParameterSet(DataSet.PartitionMethod.PARTITION_BY_NUMBER_OF_ELEMENTS, contData.getElementLength(), true, 2, new double[]{0.3, 0.3});
        RepeatedHoldOutExperiment exp = new RepeatedHoldOutExperiment(cl);
        System.out.println(exp.assess(mp, assessPS, null, new DataSet[]{contData, contData}, new double[][]{weights, bgWeights}));

        // Train on the full training data, persist the model, and report the training-set correlation.
        cl.train(new DataSet[]{contData, contData}, new double[][]{weights, bgWeights});
        FileManager.writeFile(MODEL_OUTPUT_PATH, (CharSequence)cl.toXML());
        System.out.println(cl);
        double[] pred = cl.getScores(contData);
        System.out.println(ToolBox.spearmanCorrelation(weights, pred));

        // Score the held-out test file (the last argument).
        Pair<DataSet, double[][]> test = DimerTest.getData(args[args.length - 1]);
        pred = cl.getScores(test.getFirstElement());
        System.out.println(Arrays.toString(pred));
        double[] testIntensities = test.getSecondElement()[0];
        System.out.println(ToolBox.spearmanCorrelation(testIntensities, pred));

        // Split the predictions by whether the measured intensity exceeds 10.
        DoubleList pos = new DoubleList();
        DoubleList neg = new DoubleList();
        for (int i = 0; i < testIntensities.length; ++i) {
            if (testIntensities[i] > 10.0) {
                pos.add(pred[i]);
            } else {
                neg.add(pred[i]);
            }
        }
        System.out.println(pos);
        System.out.println(neg);
    }
}

