/*
 * Decompiled with CFR 0.152.
 */
package projects.talmodel;

import de.jstacs.algorithms.optimization.termination.SmallDifferenceOfFunctionEvaluationsCondition;
import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.DNADataSet;
import de.jstacs.data.alphabets.DNAAlphabetContainer;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.io.ArrayHandler;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.models.DifferentiableHigherOrderHMM;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.models.HigherOrderHMM;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.states.emissions.CodonEmission;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.states.emissions.DifferentiableEmission;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.states.emissions.DummyEmission;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.states.emissions.Emission;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.states.emissions.SilentEmission;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.states.emissions.discrete.DiscreteEmission;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.training.BaumWelchParameterSet;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.training.HMMTrainingParameterSet;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.training.NumericalHMMTrainingParameterSet;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.transitions.BasicHigherOrderTransition;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.transitions.elements.TransitionElement;
import java.io.IOException;
import java.io.PrintWriter;
import java.text.DecimalFormat;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.ListIterator;
import java.util.Locale;

/**
 * Builds layered codon HMMs from three kinds of modules per layer: an optional delete module,
 * an insert module and a match module (in that state order).
 *
 * <p>All builder methods append to four parallel structures: {@code names} and {@code emissionIdx}
 * hold one entry per state (the list position is the state index), {@code emissions} is the shared
 * emission table, and {@code trans} holds the transition elements. The {@code startIdx} arguments are
 * the index of the last state appended before the module, and the {@code int} return values are the
 * index of the last state the module appended.
 */
public class TALModelFactory {
    /** Equivalent sample size of the insert emissions and of the differentiable HMM. */
    static double ess;
    /** Equivalent sample size spread over the children of every transition element. */
    private static final double transEss;
    /** Equivalent sample size passed to every {@link CodonEmission}. */
    private static final double codonEss;
    private static final DummyEmission dummy;
    private static final SilentEmission silent;
    private static final DiscreteEmission insert1;
    private static final DiscreteEmission insert2;
    private static final DiscreteEmission insert3;
    private static final DiscreteEmission insertD;

    // Positions of the shared emissions in the table built at the start of createHMM.
    // The add order in createHMM must match these constants.
    private static final int DUMMY_IDX = 0;
    private static final int SILENT_IDX = 1;
    private static final int INSERT1_IDX = 2;
    private static final int INSERT2_IDX = 3;
    private static final int INSERT3_IDX = 4;
    private static final int INSERT_D_IDX = 5;

    /** Nucleotide prototype used by {@link #main} and {@link #mainRepeats}; codon i seeds repeat layer i. */
    private static final String REPEAT_PROTOTYPE = "CTGACCCCGGACCAGGTCGTGGCCATTGCCAGCAATAACGGCGGCAAGCAGGCGCTGGAGACGGTGCAGCGGCTGTTGCCGGTGCTGTGCCAGGACCATGGC";

    static {
        ess = 4.0;
        transEss = ess;
        codonEss = ess;
        dummy = new DummyEmission(DNAAlphabetContainer.SINGLETON);
        silent = new SilentEmission();
        insert1 = new DiscreteEmission((AlphabetContainer)DNAAlphabetContainer.SINGLETON, ess);
        insert2 = new DiscreteEmission((AlphabetContainer)DNAAlphabetContainer.SINGLETON, ess);
        insert3 = new DiscreteEmission((AlphabetContainer)DNAAlphabetContainer.SINGLETON, ess);
        insertD = new DiscreteEmission((AlphabetContainer)DNAAlphabetContainer.SINGLETON, ess);
    }

    /**
     * Appends a match module: two dummy states {@code Du<layer>.1}, {@code Du<layer>.2} followed by the
     * codon state {@code Co<layer>}, chained Du.1 -> Du.2 -> Co. A new {@link CodonEmission} is appended
     * to {@code emissions} for the codon state.
     *
     * @param emissions      shared emission table; one codon emission is appended
     * @param emissionIdx    per-state emission indices; three entries are appended
     * @param trans          transition elements; the two internal edges are appended
     * @param names          per-state names; three entries are appended
     * @param layerIndex     layer number used in the state names
     * @param startIdx       index of the last state before this module
     * @param codonPrototype codon passed to the {@link CodonEmission}; may be {@code null}
     * @return index of the codon state (the last state of the module)
     */
    public static int getMatchModule(LinkedList<Emission> emissions, LinkedList<Integer> emissionIdx, LinkedList<TransitionElement> trans, LinkedList<String> names, int layerIndex, int startIdx, Sequence codonPrototype) {
        int dummy1 = startIdx + 1;
        int dummy2 = startIdx + 2;
        int codon = startIdx + 3;
        emissionIdx.add(DUMMY_IDX);
        emissionIdx.add(DUMMY_IDX);
        emissionIdx.add(emissions.size()); // the codon emission appended below
        names.add("Du" + layerIndex + ".1");
        names.add("Du" + layerIndex + ".2");
        names.add("Co" + layerIndex);
        emissions.add(new CodonEmission(true, codonEss, codonPrototype, 0.5));
        trans.add(new TransitionElement(new int[]{dummy1}, new int[]{dummy2}, new double[]{transEss}));
        trans.add(new TransitionElement(new int[]{dummy2}, new int[]{codon}, new double[]{transEss}));
        return codon;
    }

    /**
     * Appends an insert module with states {@code I<layer>.1}, {@code .2}, {@code .3} (emitting via the
     * shared insert emissions) and the silent exit {@code I<layer>.D}. I.1 -> {I.2, I.D},
     * I.2 -> {I.3, I.D}, I.3 -> {I.D, I.1}. The outgoing edge of I.D is added by
     * {@link #getOneLayer}.
     *
     * @param emissions   shared emission table (not modified; shared indices are used)
     * @param emissionIdx per-state emission indices; four entries are appended
     * @param trans       transition elements; the three internal elements are appended
     * @param names       per-state names; four entries are appended
     * @param layerIndex  layer number used in the state names
     * @param startIdx    index of the last state before this module
     * @return index of the silent exit state {@code I<layer>.D}
     */
    public static int getInsertModule(LinkedList<Emission> emissions, LinkedList<Integer> emissionIdx, LinkedList<TransitionElement> trans, LinkedList<String> names, int layerIndex, int startIdx) {
        int i1 = startIdx + 1;
        int i2 = startIdx + 2;
        int i3 = startIdx + 3;
        int exit = startIdx + 4;
        double half = transEss / 2.0;
        emissionIdx.add(INSERT1_IDX);
        emissionIdx.add(INSERT2_IDX);
        emissionIdx.add(INSERT3_IDX);
        emissionIdx.add(SILENT_IDX);
        names.add("I" + layerIndex + ".1");
        names.add("I" + layerIndex + ".2");
        names.add("I" + layerIndex + ".3");
        names.add("I" + layerIndex + ".D");
        trans.add(new TransitionElement(new int[]{i1}, new int[]{i2, exit}, new double[]{half, half}));
        trans.add(new TransitionElement(new int[]{i2}, new int[]{i3, exit}, new double[]{half, half}));
        trans.add(new TransitionElement(new int[]{i3}, new int[]{exit, i1}, new double[]{half, half}));
        return exit;
    }

    /**
     * Appends a delete module: silent {@code D<layer>.D1}, two states {@code D<layer>.I1},
     * {@code D<layer>.I2} emitting via the shared {@code insertD} emission, and silent {@code D<layer>.D2}.
     * D1 -> {I1, D2}, I1 -> {I2, D2}, I2 -> D2. The outgoing edges of D2 are added when the next layer
     * is built (see {@link #getOneLayer}).
     *
     * @param emissions   shared emission table (not modified; shared indices are used)
     * @param emissionIdx per-state emission indices; four entries are appended
     * @param trans       transition elements; the three internal elements are appended
     * @param names       per-state names; four entries are appended
     * @param layerIndex  layer number used in the state names
     * @param startIdx    index of the last state before this module
     * @return index of the silent exit state {@code D<layer>.D2}
     */
    public static int getDeleteModule(LinkedList<Emission> emissions, LinkedList<Integer> emissionIdx, LinkedList<TransitionElement> trans, LinkedList<String> names, int layerIndex, int startIdx) {
        int d1 = startIdx + 1;
        int dI1 = startIdx + 2;
        int dI2 = startIdx + 3;
        int d2 = startIdx + 4;
        double half = transEss / 2.0;
        emissionIdx.add(SILENT_IDX);
        emissionIdx.add(INSERT_D_IDX);
        emissionIdx.add(INSERT_D_IDX);
        emissionIdx.add(SILENT_IDX);
        names.add("D" + layerIndex + ".D1");
        names.add("D" + layerIndex + ".I1");
        names.add("D" + layerIndex + ".I2");
        names.add("D" + layerIndex + ".D2");
        trans.add(new TransitionElement(new int[]{d1}, new int[]{dI1, d2}, new double[]{half, half}));
        trans.add(new TransitionElement(new int[]{dI1}, new int[]{dI2, d2}, new double[]{half, half}));
        trans.add(new TransitionElement(new int[]{dI2}, new int[]{d2}, new double[]{transEss}));
        return d2;
    }

    /**
     * Appends one layer (optional delete module, insert module, match module) and wires it to the
     * previous layer.
     *
     * <p>Edges added here: previous delete exit -> {this delete entry, match entry} (only if
     * {@code prevDel >= 0}); previous match -> {delete entry, match entry, insert entry} (delete entry
     * only when {@code withDelete}); insert exit -> match entry.
     *
     * @param prevDel        exit state of the previous layer's delete module, or negative if none
     * @param prevMatch      state this layer is entered from (previous codon state or a silent state)
     * @param withDelete     whether to create a delete module in this layer
     * @param codonPrototype codon for this layer's {@link CodonEmission}; may be {@code null}
     * @return {@code {deleteExit or -1, codonState}}
     */
    public static int[] getOneLayer(LinkedList<Emission> emissions, LinkedList<Integer> emissionIdx, LinkedList<TransitionElement> trans, LinkedList<String> names, int layerIndex, int startIdx, int prevDel, int prevMatch, boolean withDelete, Sequence codonPrototype) {
        int firstDelete = -1;
        int lastDelete = -1;
        if (withDelete) {
            firstDelete = startIdx + 1;
            lastDelete = getDeleteModule(emissions, emissionIdx, trans, names, layerIndex, startIdx);
            startIdx = lastDelete;
        }
        int firstInsert = startIdx + 1;
        int lastInsert = getInsertModule(emissions, emissionIdx, trans, names, layerIndex, startIdx);
        int firstMatch = lastInsert + 1;
        int lastMatch = getMatchModule(emissions, emissionIdx, trans, names, layerIndex, lastInsert, codonPrototype);
        if (prevDel >= 0) {
            if (withDelete) {
                trans.add(new TransitionElement(new int[]{prevDel}, new int[]{firstDelete, firstMatch}, new double[]{transEss / 2.0, transEss / 2.0}));
            } else {
                trans.add(new TransitionElement(new int[]{prevDel}, new int[]{firstMatch}, new double[]{transEss}));
            }
        }
        if (withDelete) {
            trans.add(new TransitionElement(new int[]{prevMatch}, new int[]{firstDelete, firstMatch, firstInsert}, new double[]{transEss / 3.0, transEss / 3.0, transEss / 3.0}));
        } else {
            trans.add(new TransitionElement(new int[]{prevMatch}, new int[]{firstMatch, firstInsert}, new double[]{transEss / 2.0, transEss / 2.0}));
        }
        trans.add(new TransitionElement(new int[]{lastInsert}, new int[]{firstMatch}, new double[]{transEss}));
        return new int[]{lastDelete, lastMatch};
    }

    /**
     * Appends {@code layers} consecutive layers. Every layer except the last gets a delete module,
     * so every delete path ends in a following layer.
     *
     * @param layers          number of layers, at least 1
     * @param startLayer      layers are named {@code startLayer + 1 ... startLayer + layers}
     * @param startIdx        state the first layer is entered from
     * @param escapeLayer     1-based layer (within this module) whose result is also reported as the
     *                        escape layer; values outside {@code [1, layers]} report none
     * @param repeatPrototype optional prototype; codon {@code i} seeds layer {@code i} when
     *                        {@code usePrototype[i]} is true. May be {@code null}.
     * @param usePrototype    per-layer flags; only read when {@code repeatPrototype} is non-null
     * @return {@code {lastDelete, lastMatch, escapeDelete, escapeMatch}}; the escape entries are -1 when
     *         no escape layer was hit
     * @throws IllegalArgumentException if {@code layers < 1}
     */
    public static int[] getModule(int layers, LinkedList<Emission> emissions, LinkedList<Integer> emissionIdx, LinkedList<TransitionElement> trans, LinkedList<String> names, int startLayer, int startIdx, int escapeLayer, Sequence repeatPrototype, boolean[] usePrototype) {
        if (layers < 1) {
            throw new IllegalArgumentException("layers must be at least 1, got " + layers);
        }
        int[] last = null;
        int[] escape = new int[]{-1, -1};
        int prevDel = -1;
        int prevMatch = startIdx;
        for (int i = 0; i < layers; i++) {
            Sequence codon = repeatPrototype == null || !usePrototype[i] ? null : repeatPrototype.getSubSequence(i * 3, 3);
            // Each layer starts right after the previous layer's codon state (or at startIdx).
            last = getOneLayer(emissions, emissionIdx, trans, names, startLayer + 1 + i, prevMatch, prevDel, prevMatch, i < layers - 1, codon);
            if (i + 1 == escapeLayer) {
                escape = last.clone();
            }
            prevDel = last[0];
            prevMatch = last[1];
        }
        return new int[]{last[0], last[1], escape[0], escape[1]};
    }

    /**
     * Builds the complete HMM: silent start state {@code S}, then optional start, repeat and end
     * modules separated by silent states {@code T1}, {@code T2} and {@code E}.
     *
     * <p>The repeat module's last codon state loops back to the state before the module ({@code T1},
     * or {@code S} without a start module). The codon state of the repeat layer
     * {@code layerEscapeCycle} additionally gets an edge to {@code T2}.
     *
     * @param layersStart      layers of the start module (built without prototype); 0 to omit
     * @param layersRepeat     layers of the repeat module; 0 to omit
     * @param layersEnd        layers of the end module (built without prototype); 0 to omit
     * @param layerEscapeCycle 1-based repeat layer with the extra edge to {@code T2}
     * @param repeatPrototype  optional prototype seeding the repeat module's codon emissions
     * @param usePrototype     per-repeat-layer flags for using the prototype codon
     * @param threads          number of threads passed to the training parameter set
     * @param bw               {@code true} for a {@link HigherOrderHMM} with Baum-Welch training,
     *                         {@code false} for a {@link DifferentiableHigherOrderHMM} with numerical training
     * @return the untrained HMM
     * @throws Exception if the Jstacs constructors fail
     */
    public static HigherOrderHMM createHMM(int layersStart, int layersRepeat, int layersEnd, int layerEscapeCycle, Sequence repeatPrototype, boolean[] usePrototype, int threads, boolean bw) throws Exception {
        LinkedList<Emission> emissions = new LinkedList<Emission>();
        LinkedList<Integer> emissionIdx = new LinkedList<Integer>();
        LinkedList<TransitionElement> trans = new LinkedList<TransitionElement>();
        LinkedList<String> names = new LinkedList<String>();
        // Shared emission table; the order must match the *_IDX constants.
        emissions.add(dummy);
        emissions.add(silent);
        emissions.add(insert1);
        emissions.add(insert2);
        emissions.add(insert3);
        emissions.add(insertD);
        // State 0: silent start state, entered by the initial (context-free) transition.
        emissionIdx.add(SILENT_IDX);
        names.add("S");
        trans.add(new TransitionElement(null, new int[]{0}, new double[]{transEss}));
        int[] last = new int[]{-1, -1};
        if (layersStart > 0) {
            last = getModule(layersStart, emissions, emissionIdx, trans, names, 0, 0, -1, null, null);
            System.out.println("T1: " + (last[1] + 1));
            emissionIdx.add(SILENT_IDX);
            names.add("T1");
            trans.add(new TransitionElement(new int[]{last[1]}, new int[]{last[1] + 1}, new double[]{transEss}));
        }
        if (layersRepeat > 0) {
            int ret = last[1] + 1;
            System.out.println("ret: " + ret);
            last = getModule(layersRepeat, emissions, emissionIdx, trans, names, layersStart, ret, layerEscapeCycle, repeatPrototype, usePrototype);
            int t2 = last[1] + 1;
            System.out.println("T2: " + t2);
            emissionIdx.add(SILENT_IDX);
            names.add("T2");
            // Close the repeat cycle BEFORE adding the escape edge, so an escape from the last
            // repeat layer extends this back edge instead of being lost.
            trans.add(new TransitionElement(new int[]{last[1]}, new int[]{ret}, new double[]{transEss}));
            if (last[3] >= 0) {
                addChildToContext(trans, last[3], t2);
            }
        }
        if (layersEnd > 0) {
            last = getModule(layersEnd, emissions, emissionIdx, trans, names, layersStart + layersRepeat, last[1] + 1, -1, null, null);
            emissionIdx.add(SILENT_IDX);
            names.add("E");
            trans.add(new TransitionElement(new int[]{last[1]}, new int[]{last[1] + 1}, new double[]{transEss}));
        }
        boolean[] forward = new boolean[emissionIdx.size()];
        Arrays.fill(forward, true);
        int[] emissionIdxAr = new int[emissionIdx.size()];
        int k = 0;
        for (Integer e : emissionIdx) {
            emissionIdxAr[k++] = e;
        }
        if (bw) {
            BaumWelchParameterSet trainingParameterSet = new BaumWelchParameterSet(10, new SmallDifferenceOfFunctionEvaluationsCondition(1.0E-6), threads);
            return new HigherOrderHMM((HMMTrainingParameterSet)trainingParameterSet, names.toArray(new String[0]), emissionIdxAr, forward, emissions.toArray(new Emission[0]), (BasicHigherOrderTransition.AbstractTransitionElement[])trans.toArray(new TransitionElement[0]));
        }
        NumericalHMMTrainingParameterSet trainingParameterSet = new NumericalHMMTrainingParameterSet(30, new SmallDifferenceOfFunctionEvaluationsCondition(1.0E-6), threads, 18, 1.0E-6, 1.0E-6);
        return new DifferentiableHigherOrderHMM(trainingParameterSet, names.toArray(new String[0]), emissionIdxAr, forward, ArrayHandler.cast(DifferentiableEmission.class, emissions.toArray(new Emission[0])), true, ess, trans.toArray(new TransitionElement[0]));
    }

    /**
     * Replaces every transition element (except the initial one at position 0) whose last context state
     * is {@code state} with one that has {@code target} as an additional child. {@link #transEss} is spread
     * uniformly over all children.
     */
    private static void addChildToContext(LinkedList<TransitionElement> trans, int state, int target) {
        // listIterator + set avoids the O(n^2) index-based access on a LinkedList.
        ListIterator<TransitionElement> it = trans.listIterator(1);
        while (it.hasNext()) {
            TransitionElement te = it.next();
            if (te.getLastContextState() != state) {
                continue;
            }
            int n = te.getNumberOfChildren();
            int[] children = new int[n + 1];
            for (int j = 0; j < n; j++) {
                children[j] = te.getChild(j);
            }
            children[n] = target;
            double[] hyper = new double[children.length];
            Arrays.fill(hyper, transEss / hyper.length);
            it.set(new TransitionElement(new int[]{state}, children, hyper));
        }
    }

    /**
     * Trains a model whose start module length is derived from the average sequence length of the data.
     * Arguments: data file, output XML file, number of threads, use Baum-Welch ({@code true}/{@code false}).
     */
    public static void main(String[] args) throws Exception {
        checkArgs(args);
        Sequence repeatPrototype = Sequence.create(DNAAlphabetContainer.SINGLETON, REPEAT_PROTOTYPE);
        boolean[] usePrototype = createPrototypeMask(repeatPrototype);
        DNADataSet ds = new DNADataSet(args[0]);
        System.out.println(ds.getAverageElementLength());
        HigherOrderHMM hmm = createHMM((int)(ds.getAverageElementLength() / 3.0), 0, 0, 20, repeatPrototype, usePrototype, Integer.parseInt(args[2]), Boolean.parseBoolean(args[3]));
        trainAndWrite(hmm, ds, args[1]);
    }

    /**
     * Trains a 34-layer repeat-only model seeded with the repeat prototype.
     * Arguments: data file, output XML file, number of threads, use Baum-Welch ({@code true}/{@code false}).
     */
    public static void mainRepeats(String[] args) throws Exception {
        checkArgs(args);
        Sequence repeatPrototype = Sequence.create(DNAAlphabetContainer.SINGLETON, REPEAT_PROTOTYPE);
        boolean[] usePrototype = createPrototypeMask(repeatPrototype);
        HigherOrderHMM hmm = createHMM(0, 34, 0, 20, repeatPrototype, usePrototype, Integer.parseInt(args[2]), Boolean.parseBoolean(args[3]));
        DNADataSet ds = new DNADataSet(args[0]);
        trainAndWrite(hmm, ds, args[1]);
    }

    /** Fails fast with a usage message instead of an ArrayIndexOutOfBoundsException. */
    private static void checkArgs(String[] args) {
        if (args.length < 4) {
            throw new IllegalArgumentException("expected arguments: <data file> <output XML file> <threads> <use Baum-Welch (true/false)>");
        }
    }

    /**
     * Returns one flag per prototype codon, all {@code true} except codons 11 and 12 (0-based).
     * NOTE(review): presumably these are the repeat-variable positions — confirm.
     */
    private static boolean[] createPrototypeMask(Sequence repeatPrototype) {
        boolean[] usePrototype = new boolean[repeatPrototype.getLength() / 3];
        Arrays.fill(usePrototype, true);
        usePrototype[12] = false;
        usePrototype[11] = false;
        return usePrototype;
    }

    /** Trains {@code hmm} on {@code ds}, prints its Graphviz representation and writes its XML to {@code outFile}. */
    private static void trainAndWrite(HigherOrderHMM hmm, DNADataSet ds, String outFile) throws Exception {
        System.out.println(ds.getNumberOfElements());
        hmm.train(ds);
        System.out.println(hmm.getGraphvizRepresentation(DecimalFormat.getInstance(Locale.ENGLISH), false));
        PrintWriter wr = new PrintWriter(outFile);
        try {
            wr.println(hmm.toXML());
            // PrintWriter swallows I/O errors; surface them instead of leaving a truncated file.
            if (wr.checkError()) {
                throw new IOException("Error while writing " + outFile);
            }
        } finally {
            wr.close();
        }
    }
}

