/*
 * Decompiled with CFR 0.152.
 */
package projects.sigma;

import de.jstacs.DataType;
import de.jstacs.algorithms.optimization.termination.SmallDifferenceOfFunctionEvaluationsCondition;
import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.DataSet;
import de.jstacs.data.alphabets.DNAAlphabetContainer;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.data.sequences.annotation.SequenceAnnotationParser;
import de.jstacs.data.sequences.annotation.SimpleSequenceAnnotationParser;
import de.jstacs.io.ArrayHandler;
import de.jstacs.io.SparseStringExtractor;
import de.jstacs.parameters.FileParameter;
import de.jstacs.parameters.Parameter;
import de.jstacs.parameters.ParameterException;
import de.jstacs.parameters.SimpleParameter;
import de.jstacs.parameters.validation.NumberValidator;
import de.jstacs.results.CategoricalResult;
import de.jstacs.results.DataSetResult;
import de.jstacs.results.ListResult;
import de.jstacs.results.NumericalResult;
import de.jstacs.results.PlotGeneratorResult;
import de.jstacs.results.Result;
import de.jstacs.results.ResultSet;
import de.jstacs.results.StorableResult;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.models.HigherOrderHMM;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.states.emissions.DifferentiableEmission;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.states.emissions.Emission;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.states.emissions.SilentEmission;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.states.emissions.discrete.AbstractConditionalDiscreteEmission;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.states.emissions.discrete.DiscreteEmission;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.training.HMMTrainingParameterSet;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.training.ViterbiParameterSet;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.transitions.BasicHigherOrderTransition;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.transitions.elements.TransitionElement;
import de.jstacs.tools.JstacsTool;
import de.jstacs.tools.ProgressUpdater;
import de.jstacs.tools.Protocol;
import de.jstacs.tools.ProtocolOutputStream;
import de.jstacs.tools.ToolParameterSet;
import de.jstacs.tools.ToolResult;
import de.jstacs.tools.ui.cli.CLI;
import de.jstacs.utils.IntList;
import de.jstacs.utils.Normalisation;
import de.jstacs.utils.Pair;
import de.jstacs.utils.SeqLogoPlotter;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Date;
import java.util.LinkedList;

/**
 * Jstacs tool that iteratively searches for a fixed-length motif inside a window of
 * equal-length DNA sequences.
 * <p>
 * In each round a {@link HigherOrderHMM} containing a motif path is trained on the sequences
 * that are still unassigned. Sequences whose Viterbi-decoded motif instance scores a
 * log-likelihood ratio above {@link #LLR_THRESHOLD} are assigned to that round's model; all
 * other sequences are passed on to the next round.
 */
public class SigmaIter implements JstacsTool {

    /** A new round is only started while more than this many sequences remain unassigned. */
    private static final int MIN_REMAINING_SEQUENCES = 10;

    /** Hyper-parameter scale ({@code ess}) handed to {@link #trainModel} for all priors. */
    private static final double ESS = 16.0;

    /**
     * Minimum score returned by {@link #getLLR} for a sequence to be assigned to the current
     * model (a ratio greater than 2 if the scores are log-probabilities).
     */
    private static final double LLR_THRESHOLD = Math.log(2.0);

    /**
     * Command-line entry point; delegates argument handling to Jstacs' {@link CLI}.
     *
     * @param args command-line arguments forwarded to the CLI
     * @throws Exception if the tool run fails
     */
    public static void main(String[] args) throws Exception {
        CLI cli = new CLI(new boolean[]{true}, new SigmaIter());
        cli.run(args);
    }

    /**
     * Declares the tool parameters: the FASTA input, the window bounds "Start"/"End", the motif
     * "Length" (3..20) and the "Number of Starts" for training.
     *
     * @return the parameter set of this tool
     * @throws IllegalStateException if one of the (constant) parameter definitions is rejected
     */
    @Override
    public ToolParameterSet getToolParameters() {
        LinkedList<Parameter> pars = new LinkedList<Parameter>();
        pars.add(new FileParameter("Input sequences", "", "fasta,fas,fa", true));
        try {
            pars.add(new SimpleParameter(DataType.INT, "Start", "", true, 50));
            pars.add(new SimpleParameter(DataType.INT, "End", "", true, 70));
            pars.add(new SimpleParameter(DataType.INT, "Length", "", true, new NumberValidator<Integer>(3, 20), 10));
            pars.add(new SimpleParameter(DataType.INT, "Number of Starts", "", true, 500));
        } catch (ParameterException e) {
            // The definitions above are constants, so a failure is a programming error. Returning
            // an incomplete set would only surface later, when run() reads parameters by index.
            throw new IllegalStateException("Could not create the parameters of tool " + this.getShortName(), e);
        }
        return new ToolParameterSet(this.getShortName(), pars.toArray(new Parameter[0]));
    }

    /**
     * Runs the iterative motif search.
     * <p>
     * While more than {@link #MIN_REMAINING_SEQUENCES} sequences remain, an HMM is trained on
     * them and stored. The sequences assigned to it (plus a sequence logo of its motif) are added
     * to the result, and the unassigned sequences form the data of the next round. The search
     * stops early if a model assigns no sequence or if every sequence was assigned.
     *
     * @param parameters the parameters declared by {@link #getToolParameters()}
     * @param protocol   protocol receiving the training output of the HMMs
     * @param progress   unused
     * @param threads    number of threads used for HMM training
     * @return the stored models, logos, assigned sequences and a "Predictions" table with one
     *         row per input sequence
     * @throws Exception if reading the input or training/evaluating a model fails
     */
    @Override
    public ToolResult run(ToolParameterSet parameters, Protocol protocol, ProgressUpdater progress, int threads) throws Exception {
        SimpleSequenceAnnotationParser parser = new SimpleSequenceAnnotationParser();
        String fasta = ((FileParameter) parameters.getParameterAt(0)).getFileContents().getContent();
        DataSet orig = new DataSet(DNAAlphabetContainer.SINGLETON, new SparseStringExtractor(new StringReader(fasta), '>', "", (SequenceAnnotationParser) parser));
        int start = (Integer) parameters.getParameterAt(1).getValue();
        int end = (Integer) parameters.getParameterAt(2).getValue();
        int length = (Integer) parameters.getParameterAt(3).getValue();
        int nStarts = (Integer) parameters.getParameterAt(4).getValue();

        DataSet data = orig;
        int off = 1;
        LinkedList<Result> ress = new LinkedList<Result>();
        // idx[i] is the index in `orig` of the i-th sequence in the current `data`
        int[] idx = new int[data.getNumberOfElements()];
        for (int i = 0; i < idx.length; i++) {
            idx[i] = i;
        }
        ResultSet[] rs = new ResultSet[orig.getNumberOfElements()];
        while (data.getNumberOfElements() > MIN_REMAINING_SEQUENCES) {
            HigherOrderHMM hmm = this.trainModel(data, start, end, length, ESS, nStarts, protocol, threads);
            ress.add(new StorableResult("HMM" + off, "", hmm));
            IntermediateResult ir = SigmaIter.evaluateModel(data, hmm, idx, length, off, rs);
            if (ir.in.size() == 0) {
                break;
            }
            SigmaIter.addSeqLogos(ir.models.getFirstElement(), ress, off);
            DataSet inSet = new DataSet("", ir.in);
            DataSetResult dsr = new DataSetResult("Sequences for model " + off, "", inSet);
            dsr.setParser(parser);
            ress.add(dsr);
            if (ir.out.size() == 0) {
                break;
            }
            data = new DataSet("", ir.out);
            idx = ir.idx.toArray();
            ++off;
        }
        // sequences never evaluated by any model still get a placeholder row
        for (int i = 0; i < rs.length; i++) {
            if (rs[i] == null) {
                rs[i] = predictionRow(i, -1, "NA", Double.NEGATIVE_INFINITY, -1, Double.NEGATIVE_INFINITY, annotationOf(orig.getElementAt(i)));
            }
        }
        ress.add(new ListResult("Predictions", "", null, rs));
        return new ToolResult("Result of " + this.getToolName(), this.getToolName(), null, new ResultSet(ress), parameters, this.getToolName(), new Date(System.currentTimeMillis()));
    }

    /**
     * Decodes every sequence of {@code data} with {@code hmm} and splits it into assigned and
     * unassigned sequences.
     * <p>
     * A sequence is assigned ("in") if its Viterbi path passes through the motif state "M0" and
     * the decoded motif instance has {@link #getLLR} above {@link #LLR_THRESHOLD}. Otherwise it
     * is unassigned ("out"), and its original index is kept in the returned index list. For
     * every sequence a prediction row is written to {@code rs} at its original index.
     *
     * @param data   the sequences to evaluate
     * @param hmm    the trained model
     * @param idx    original indices of the sequences in {@code data}
     * @param length motif length used when the model was built
     * @param off    number of the current round, reported as "component"
     * @param rs     prediction rows indexed by original sequence index (modified in place)
     * @return assigned/unassigned sequences, original indices of the unassigned ones, and the
     *         motif/background parameters extracted from {@code hmm}
     * @throws Exception if decoding fails
     */
    static IntermediateResult evaluateModel(DataSet data, HigherOrderHMM hmm, int[] idx, int length, int off, ResultSet[] rs) throws Exception {
        Pair<double[][], double[]> models = SigmaIter.getModels(hmm, length);
        LinkedList<Sequence> in = new LinkedList<Sequence>();
        IntList idx2 = new IntList();
        LinkedList<Sequence> out = new LinkedList<Sequence>();
        for (int i = 0; i < data.getNumberOfElements(); i++) {
            Sequence seq = data.getElementAt(i);
            int origIndex = idx[i];
            Pair<IntList, Double> viterbi = hmm.getViterbiPathFor(seq);
            String[] path = hmm.decodePath(viterbi.getFirstElement());
            int seqStart = -1;
            int seqEnd = -1;
            // the motif instance spans from the first motif state "M0" up to (excluding) "E"
            for (int j = 0; j < path.length; j++) {
                if (path[j].startsWith("M0")) {
                    seqStart = j;
                }
                if (path[j].startsWith("E")) {
                    seqEnd = j;
                    break;
                }
            }
            if (seqStart > -1 && seqEnd > seqStart) {
                Sequence sub = seq.getSubSequence(seqStart, seqEnd - seqStart);
                double llr = SigmaIter.getLLR(sub, models.getFirstElement(), models.getSecondElement());
                if (llr > LLR_THRESHOLD) {
                    in.add(seq);
                } else {
                    out.add(seq);
                    idx2.add(origIndex);
                }
                rs[origIndex] = predictionRow(origIndex, seqStart, sub.toString(), viterbi.getSecondElement(), off, llr, annotationOf(seq));
            } else {
                // no complete motif on the Viterbi path: leave the sequence for the next round
                out.add(seq);
                idx2.add(origIndex);
                rs[origIndex] = predictionRow(origIndex, -1, "NA", Double.NEGATIVE_INFINITY, -1, Double.NEGATIVE_INFINITY, annotationOf(seq));
            }
        }
        return new IntermediateResult(idx2, in, out, models);
    }

    /**
     * Builds one row of the "Predictions" table.
     *
     * @return a single-row result set with the columns Index, start, seq, Score, component, LLR
     *         and annotation
     */
    private static ResultSet predictionRow(int index, int start, String site, double score, int component, double llr, String annotation) {
        return new ResultSet(new Result[][]{{
            new NumericalResult("Index", "", index),
            new NumericalResult("start", "", start),
            new CategoricalResult("seq", "", site),
            new NumericalResult("Score", "", score),
            new NumericalResult("component", "", component),
            new NumericalResult("LLR", "", llr),
            new CategoricalResult("annotation", "", annotation)}});
    }

    /**
     * Returns the value of the first result of the first annotation of {@code seq}, which the
     * FASTA parser in {@link #run} attaches from the header line.
     */
    private static String annotationOf(Sequence seq) {
        return (String) seq.getAnnotation()[0].getResultAt(0).getValue();
    }

    /**
     * Scores {@code sub} as the sum over positions of motif parameter minus background parameter
     * of the observed symbol.
     * <p>
     * NOTE(review): the parameters come from {@code fillCurrentParameter}. {@link #addSeqLogos}
     * normalises the same values with {@code logSumNormalisation} before plotting. If they are
     * not already normalised log-probabilities, this score is shifted by a per-position constant.
     * Confirm against the Jstacs emission implementation.
     *
     * @param sub   subsequence of length at least {@code pwm.length}
     * @param pwm   per-position motif parameters, indexed by symbol
     * @param bgMod background parameters, indexed by symbol
     * @return the summed difference of motif and background parameters
     */
    private static double getLLR(Sequence sub, double[][] pwm, double[] bgMod) {
        double score = 0.0;
        for (int i = 0; i < pwm.length; i++) {
            int symbol = sub.discreteVal(i);
            score += pwm[i][symbol] - bgMod[symbol];
        }
        return score;
    }

    /**
     * Adds a sequence-logo plot of the motif to {@code ress}. The input is copied and each
     * position is converted with {@code logSumNormalisation} before plotting, so the caller's
     * array stays unchanged.
     *
     * @param logPWM per-position motif parameters in log space
     * @param ress   result list receiving the plot
     * @param off    round number used in the plot name
     * @throws CloneNotSupportedException if copying the matrix fails
     */
    private static void addSeqLogos(double[][] logPWM, LinkedList<Result> ress, int off) throws CloneNotSupportedException {
        double[][] probs = (double[][]) ArrayHandler.clone((Cloneable[]) logPWM);
        for (int j = 0; j < probs.length; j++) {
            Normalisation.logSumNormalisation(probs[j]);
        }
        ress.add(new PlotGeneratorResult("SeqLogo_" + off, "", new SeqLogoPlotter.SeqLogoPlotGenerator(probs, 100), true));
    }

    /**
     * Extracts the current parameters of the motif emissions (indices 1..length) and of the
     * background emission (index 0) of a model built by {@link #trainModel}.
     *
     * @param hmm    the trained model
     * @param length motif length
     * @return pair of (per-position motif parameters, background parameters)
     * @throws CloneNotSupportedException if an emission cannot be cloned
     */
    private static Pair<double[][], double[]> getModels(HigherOrderHMM hmm, int length) throws CloneNotSupportedException {
        Emission[] emission = hmm.getEmissions();
        double[][] logPWM = new double[length][];
        for (int i = 0; i < length; i++) {
            logPWM[i] = currentParameters(emission[i + 1]);
        }
        return new Pair<double[][], double[]>(logPWM, currentParameters(emission[0]));
    }

    /**
     * Copies the current parameters of a discrete emission into a fresh array. The emission is
     * cloned first, so setting the parameter offset to 0 (the fill then starts at index 0 of the
     * local array) does not alter the model's own emission.
     */
    private static double[] currentParameters(Emission emission) throws CloneNotSupportedException {
        AbstractConditionalDiscreteEmission copy = ((DiscreteEmission) emission).clone();
        copy.setParameterOffset(0);
        double[] pars = new double[copy.getNumberOfParameters()];
        copy.fillCurrentParameter(pars);
        return pars;
    }

    /**
     * Builds and trains an HMM whose motif of length {@code len} is placed inside the window
     * given by {@code start} and {@code end}.
     * <p>
     * State layout, in order:
     * <ol>
     * <li>{@code start - 1} leading states "O"</li>
     * <li>"B"</li>
     * <li>{@code len} motif states "M0".."M(len-1)"</li>
     * <li>{@code len} alternative states "A0".."A(len-1)"</li>
     * <li>"E"</li>
     * <li>trailing "O" states up to the sequence end</li>
     * <li>the silent final state "F"</li>
     * </ol>
     * Emission 0 (background) is shared by all non-motif states, and emissions 1..len belong to
     * the motif states.
     *
     * @throws IllegalArgumentException if the sequences have no common length or the window does
     *         not fit the state layout
     * @throws Exception if model construction or training fails
     */
    private HigherOrderHMM trainModel(DataSet data, int start, int end, int len, double ess, int nStarts, Protocol protocol, int threads) throws Exception {
        int seqLen = data.getElementLength();
        checkWindow(seqLen, start, end, len);

        DiscreteEmission insert = new DiscreteEmission((AlphabetContainer) DNAAlphabetContainer.SINGLETON, ess * (double) (seqLen - len) + ess * (double) len / 2.0);
        Emission[] emission = new DifferentiableEmission[1 + len + 1];
        emission[0] = insert;
        for (int i = 0; i < len; i++) {
            emission[i + 1] = new DiscreteEmission((AlphabetContainer) DNAAlphabetContainer.SINGLETON, ess / 2.0);
        }
        emission[emission.length - 1] = new SilentEmission();

        int[] emissionIdx = new int[start - 1 + 1 + 2 * len + 1 + seqLen - end - 1 + 1];
        String[] name = new String[emissionIdx.length];
        ArrayList<TransitionElement> tes = new ArrayList<TransitionElement>();
        int k = 0;
        tes.add(new TransitionElement(null, new int[1], new double[]{ess}));
        for (int i = 0; i < start - 1; i++, k++) {
            emissionIdx[k] = 0;
            name[k] = "O" + k;
            tes.add(new TransitionElement(new int[]{k}, new int[]{k + 1}, new double[]{ess}));
        }
        double fac = (double) (end - start + 1 - len) / 2.0;
        emissionIdx[k] = 0;
        name[k] = "B";
        // B may loop, enter the motif path (M0) or enter the equally long alternative path (A0)
        tes.add(new TransitionElement(new int[]{k}, new int[]{k, k + 1, k + 1 + len}, new double[]{ess * fac, ess / 2.0, ess / 2.0}));
        k++;
        for (int i = 0; i < len; i++, k++) {
            emissionIdx[k] = i + 1;
            name[k] = "M" + i;
            // the last motif state jumps over the A path directly to E
            int next = i < len - 1 ? k + 1 : k + 1 + len;
            tes.add(new TransitionElement(new int[]{k}, new int[]{next}, new double[]{ess / 2.0}));
        }
        for (int i = 0; i < len; i++, k++) {
            emissionIdx[k] = 0;
            name[k] = "A" + i;
            tes.add(new TransitionElement(new int[]{k}, new int[]{k + 1}, new double[]{ess / 2.0}));
        }
        emissionIdx[k] = 0;
        name[k] = "E";
        tes.add(new TransitionElement(new int[]{k}, new int[]{k, k + 1}, new double[]{ess * fac, ess}));
        k++;
        for (int pos = end + 1; pos < seqLen; pos++, k++) {
            emissionIdx[k] = 0;
            name[k] = "O" + k;
            tes.add(new TransitionElement(new int[]{k}, new int[]{k + 1}, new double[]{ess}));
        }
        name[k] = "F";
        emissionIdx[k] = emission.length - 1;

        boolean[] forward = new boolean[name.length];
        Arrays.fill(forward, true);
        ViterbiParameterSet trainingParameterSet = new ViterbiParameterSet(nStarts, new SmallDifferenceOfFunctionEvaluationsCondition(1.0E-6), threads);
        HigherOrderHMM hmm = new HigherOrderHMM((HMMTrainingParameterSet) trainingParameterSet, name, emissionIdx, forward, emission, (BasicHigherOrderTransition.AbstractTransitionElement[]) tes.toArray(new TransitionElement[0]));
        hmm.setOutputStream(new ProtocolOutputStream(protocol, true));
        hmm.train(data);
        return hmm;
    }

    /**
     * Validates the window against the state layout built by {@link #trainModel}.
     * <ul>
     * <li>{@code start >= 1} and {@code end < seqLen} keep the state array large enough.</li>
     * <li>{@code end - start >= len} leaves room for "B", the {@code len} motif positions and
     * "E", each emitting at least once.</li>
     * </ul>
     *
     * @throws IllegalArgumentException if one of these conditions is violated
     */
    private static void checkWindow(int seqLen, int start, int end, int len) {
        if (seqLen <= 0) {
            throw new IllegalArgumentException("Input sequences must have a common length, but DataSet.getElementLength() returned " + seqLen);
        }
        if (start < 1) {
            throw new IllegalArgumentException("Start must be at least 1, but was " + start);
        }
        if (end >= seqLen) {
            throw new IllegalArgumentException("End must be smaller than the sequence length " + seqLen + ", but was " + end);
        }
        if (end - start < len) {
            throw new IllegalArgumentException("Window Start=" + start + ", End=" + end + " is too short for a motif of length " + len + " (End - Start must be at least " + len + ")");
        }
    }

    @Override
    public String getToolName() {
        return "iter";
    }

    @Override
    public String getToolVersion() {
        return "0.1";
    }

    @Override
    public String getShortName() {
        return "iter";
    }

    @Override
    public String getDescription() {
        return "";
    }

    @Override
    public String getHelpText() {
        return "";
    }

    @Override
    public JstacsTool.ResultEntry[] getDefaultResultInfos() {
        return null;
    }

    @Override
    public ToolResult[] getTestCases(String path) {
        return null;
    }

    @Override
    public void clear() {
    }

    @Override
    public String[] getReferences() {
        return null;
    }

    /**
     * Outcome of one evaluation round (see {@link SigmaIter#evaluateModel}).
     */
    public static class IntermediateResult {
        /** Original indices of the sequences in {@link #out}, in the same order. */
        final IntList idx;
        /** Sequences assigned to the evaluated model. */
        final LinkedList<Sequence> in;
        /** Sequences left for the next round. */
        final LinkedList<Sequence> out;
        /** Motif parameters (first) and background parameters (second) of the model. */
        final Pair<double[][], double[]> models;

        public IntermediateResult(IntList idx, LinkedList<Sequence> in, LinkedList<Sequence> out, Pair<double[][], double[]> models) {
            this.idx = idx;
            this.in = in;
            this.out = out;
            this.models = models;
        }
    }
}

