package projects.dream2016;

import de.jstacs.algorithms.optimization.termination.SmallDifferenceOfFunctionEvaluationsCondition;
import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.DataSet;
import de.jstacs.data.alphabets.ContinuousAlphabet;
import de.jstacs.data.sequences.ArbitrarySequence;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.io.StringExtractor;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.models.DifferentiableHigherOrderHMM;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.states.emissions.DifferentiableEmission;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.states.emissions.SilentEmission;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.states.emissions.continuous.GaussianEmission;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.training.ViterbiParameterSet;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.transitions.elements.TransitionElement;
import de.jstacs.utils.IntList;
import de.jstacs.utils.ToolBox;
import java.io.File;
import java.text.DecimalFormat;
import java.util.Arrays;
import java.util.LinkedList;
import org.apache.batik.dom.events.DOMKeyboardEvent;
import org.apache.batik.dom.svg.SVGPathSegConstants;
import org.apache.batik.util.SVGConstants;
import org.apache.xmlgraphics.image.loader.spi.ImagePreloader;

/* loaded from: input_file:projects/dream2016/DNaseHMM.class */
/**
 * Command-line experiment that fits a left-to-right HMM with Gaussian emissions to smoothed,
 * per-sequence normalized signal profiles. It then reports the average signal assigned to each
 * state along the Viterbi paths. (NOTE(review): going by the class name, the profiles are
 * presumably DNase-seq coverage around regions of equal length; confirm.)
 *
 * <p>Usage: {@code DNaseHMM <file>}, where {@code <file>} holds one tab-separated sequence of
 * numbers per line. NOTE(review): '#' is handed to StringExtractor, presumably as a comment
 * marker; confirm.
 *
 * <p>State layout for {@code n = length / 3} profile positions:
 * {@code S -> F0 .. F(n-1) -> I -> B(n-1) .. B0 -> E -> End}.
 * {@code S} and {@code E} have self-loops and share emission 0.
 * {@code Fk} and {@code Bk} share emission {@code k + 1}.
 * {@code I} uses the emission with the largest initial mean.
 * {@code End} is silent.
 */
public class DNaseHMM {

    /** Number of consecutive values merged into one smoothed value. */
    private static final int SMOOTHING_WINDOW = 5;

    /**
     * Base hyper-parameter used for the HMM, its transitions and its emissions (presumably an
     * equivalent sample size; confirm against the jstacs constructors).
     */
    private static final double ESS = 4.0;

    /** Upper bound on the number of Viterbi paths printed for inspection. */
    private static final int MAX_PATHS_TO_PRINT = 100;

    /** Initial capacity handed to the {@link StringExtractor}. */
    private static final int READER_INITIAL_CAPACITY = 1000;

    /**
     * Smooths every sequence and normalises it. Each non-overlapping window of {@code window}
     * consecutive values is replaced by its median. The smoothed values are then divided by
     * their overall median; a median of 0 is replaced by 1 to avoid division by zero. Trailing
     * values that do not fill a complete window are dropped.
     *
     * @param dataSet the continuous-valued sequences to smooth
     * @param window  the window width; must be positive
     * @return a new data set with one smoothed, normalised sequence per input sequence
     * @throws IllegalArgumentException if {@code window} is not positive
     * @throws Exception if jstacs fails to create the sequences or the data set
     */
    private static DataSet medianSmoothing(DataSet dataSet, int window) throws Exception {
        if (window <= 0) {
            throw new IllegalArgumentException("window must be positive: " + window);
        }
        LinkedList<Sequence> smoothedSequences = new LinkedList<Sequence>();
        // Reused scratch buffer; it is fully overwritten for every window.
        double[] windowValues = new double[window];
        for (int s = 0; s < dataSet.getNumberOfElements(); s++) {
            Sequence original = dataSet.getElementAt(s);
            double[] smoothed = new double[original.getLength() / window];
            for (int w = 0; w < smoothed.length; w++) {
                for (int k = 0; k < window; k++) {
                    windowValues[k] = original.continuousVal(w * window + k);
                }
                smoothed[w] = ToolBox.median(windowValues);
            }
            // Normalise by the median of the entire smoothed sequence, not just the last window.
            // A copy is passed in case the median helper reorders its argument.
            double median = smoothed.length == 0 ? 0.0 : ToolBox.median(smoothed.clone());
            if (median == 0.0d) {
                median = 1.0d;
            }
            for (int w = 0; w < smoothed.length; w++) {
                smoothed[w] /= median;
            }
            smoothedSequences.add(new ArbitrarySequence(original.getAlphabetContainer(), smoothed));
        }
        return new DataSet("", smoothedSequences);
    }

    /**
     * Computes the position-wise mean over all sequences of a data set.
     *
     * <p>NOTE(review): the result has length {@code dataSet.getElementLength()}. A sequence
     * longer than that would index past the end of the result, so this assumes all sequences
     * share one length.
     *
     * @param dataSet the continuous-valued sequences
     * @return the mean value at every position
     */
    private static double[] getMean(DataSet dataSet) {
        double[] sums = new double[dataSet.getElementLength()];
        int count = dataSet.getNumberOfElements();
        for (int s = 0; s < count; s++) {
            Sequence seq = dataSet.getElementAt(s);
            for (int pos = 0; pos < seq.getLength(); pos++) {
                sums[pos] += seq.continuousVal(pos);
            }
        }
        double divisor = count;
        for (int pos = 0; pos < sums.length; pos++) {
            sums[pos] /= divisor;
        }
        return sums;
    }

    /**
     * Builds the emissions.
     *
     * <p>Indices {@code 0 .. profileLength + 1} are Gaussian. Their initial means rise linearly
     * from 0 towards {@code maxMean}. The last index is a silent emission.
     *
     * @param alphabet      the continuous alphabet of the data
     * @param profileLength number of profile positions on each side of the central state
     * @param flankEss      hyper-parameter of the flank emission (index 0) is twice this value
     * @param maxMean       scale of the initial Gaussian means
     * @return {@code profileLength + 3} emissions
     * @throws Exception if jstacs rejects an emission
     */
    private static DifferentiableEmission[] createEmissions(AlphabetContainer alphabet,
            int profileLength, double flankEss, double maxMean) throws Exception {
        DifferentiableEmission[] emissions = new DifferentiableEmission[profileLength + 3];
        int silentIndex = emissions.length - 1;
        for (int e = 0; e < silentIndex; e++) {
            // Only the flank emission (index 0) gets the larger, length-dependent hyper-parameter.
            double ess = e == 0 ? flankEss * 2.0d : ESS * 2.0d;
            // NOTE(review): presumably the prior/initial mean of the Gaussian; confirm.
            double initialMean = (e / (emissions.length - 1.0d)) * maxMean;
            System.out.println(e + " " + initialMean);
            emissions[e] = new GaussianEmission(ess, alphabet, initialMean, 1.0d, 1.0d, true);
        }
        emissions[silentIndex] = new SilentEmission();
        return emissions;
    }

    /**
     * Builds the transition elements of the left-to-right model described on the class.
     *
     * @param profileLength number of profile positions on each side of the central state
     * @param flankEss      hyper-parameter of the flank self-loops
     * @return the transition elements in the order the model expects them
     * @throws Exception if jstacs rejects a transition element
     */
    private static TransitionElement[] createTransitions(int profileLength, double flankEss)
            throws Exception {
        int centralState = profileLength + 1;
        int endFlankState = 2 * profileLength + 2;
        LinkedList<TransitionElement> transitions = new LinkedList<TransitionElement>();
        // Initial transition: always start in S (state 0).
        transitions.add(new TransitionElement(new int[0], new int[] {0}, new double[] {ESS}));
        // S: self-loop or enter F0.
        transitions.add(new TransitionElement(new int[] {0}, new int[] {0, 1},
                new double[] {flankEss, ESS}));
        // I -> B(n-1).
        transitions.add(new TransitionElement(new int[] {centralState},
                new int[] {centralState + 1}, new double[] {ESS}));
        // E: self-loop or go to End.
        transitions.add(new TransitionElement(new int[] {endFlankState},
                new int[] {endFlankState, endFlankState + 1}, new double[] {flankEss, ESS}));
        for (int k = 0; k < profileLength; k++) {
            // Forward chain F0 -> ... -> F(n-1) -> I.
            transitions.add(new TransitionElement(new int[] {k + 1}, new int[] {k + 2},
                    new double[] {ESS}));
            // Backward chain B(n-1) -> ... -> B0 -> E.
            transitions.add(new TransitionElement(new int[] {centralState + 1 + k},
                    new int[] {centralState + 2 + k}, new double[] {ESS}));
        }
        return transitions.toArray(new TransitionElement[0]);
    }

    /**
     * Averages the signal values assigned to each state along the Viterbi paths of all sequences.
     *
     * @param hmm       the trained model
     * @param data      the sequences to decode
     * @param numStates number of states of {@code hmm}
     * @return per-state mean signal; {@code NaN} (0/0) for states that no path visits
     * @throws Exception if Viterbi decoding fails
     */
    private static double[] averageSignalPerState(DifferentiableHigherOrderHMM hmm, DataSet data,
            int numStates) throws Exception {
        double[] sums = new double[numStates];
        double[] counts = new double[numStates];
        for (int s = 0; s < data.getNumberOfElements(); s++) {
            Sequence seq = data.getElementAt(s);
            IntList path = hmm.getViterbiPathFor(seq).getFirstElement();
            for (int pos = 0; pos < seq.getLength(); pos++) {
                int state = path.get(pos);
                sums[state] += seq.continuousVal(pos);
                counts[state] += 1.0d;
            }
        }
        for (int state = 0; state < numStates; state++) {
            sums[state] /= counts[state];
        }
        return sums;
    }

    /**
     * Runs the experiment on the file given as first argument.
     *
     * @param args {@code args[0]}: path to the tab-separated signal file
     * @throws Exception on I/O, data or training errors
     */
    public static void main(String[] args) throws Exception {
        if (args.length < 1) {
            System.err.println("Usage: DNaseHMM <signal-file>");
            System.exit(1);
            return;
        }
        AlphabetContainer alphabet = new AlphabetContainer(new ContinuousAlphabet());
        DataSet raw = new DataSet(alphabet,
                new StringExtractor(new File(args[0]), READER_INITIAL_CAPACITY, '#'), "\t");
        DataSet data = medianSmoothing(raw, SMOOTHING_WINDOW);

        double[] mean = getMean(data);
        double maxMean = ToolBox.max(mean);
        System.out.println(Arrays.toString(mean));
        System.out.println("num: " + data.getNumberOfElements());
        System.out.println("len: " + data.getElementLength());

        int length = data.getElementLength();
        int profileLength = length / 3;
        // Flank hyper-parameter scales with the number of positions not covered by the profile.
        double flankEss = ((length - (2 * profileLength)) * ESS) / 2.0d;

        DifferentiableEmission[] emissions =
                createEmissions(alphabet, profileLength, flankEss, maxMean);

        int numStates = 2 * profileLength + 4;
        int centralState = profileLength + 1;
        int endFlankState = 2 * profileLength + 2;
        int finalState = 2 * profileLength + 3;
        String[] stateNames = new String[numStates];
        int[] emissionIndex = new int[numStates];
        stateNames[0] = "S";
        emissionIndex[0] = 0;
        for (int k = 0; k < profileLength; k++) {
            stateNames[k + 1] = "F" + k;
            emissionIndex[k + 1] = k + 1;
            // Backward states are laid out in reverse: B(n-1) directly follows I, B0 precedes E.
            int backwardState = endFlankState - 1 - k;
            stateNames[backwardState] = "B" + k;
            emissionIndex[backwardState] = k + 1;
        }
        stateNames[centralState] = "I";
        emissionIndex[centralState] = emissions.length - 2;
        stateNames[endFlankState] = "E";
        emissionIndex[endFlankState] = 0;
        stateNames[finalState] = "End";
        emissionIndex[finalState] = emissions.length - 1;

        // NOTE(review): presumably "every state uses the forward strand"; confirm against the HMM API.
        boolean[] forward = new boolean[numStates];
        Arrays.fill(forward, true);

        // NOTE(review): the ViterbiParameterSet arguments (10, 1e-6, 4) are presumably
        // number of starts, convergence threshold and number of threads; confirm.
        DifferentiableHigherOrderHMM hmm = new DifferentiableHigherOrderHMM(
                new ViterbiParameterSet(10, new SmallDifferenceOfFunctionEvaluationsCondition(1.0E-6d), 4),
                stateNames, emissionIndex, forward, emissions, false, ESS,
                createTransitions(profileLength, flankEss));
        System.out.println(hmm.getGraphvizRepresentation(null));
        hmm.train(data);
        System.out.println(hmm);
        System.out.println(hmm.getGraphvizRepresentation(new DecimalFormat("0.000")));

        // Print a sample of Viterbi paths; bounded so small data sets do not fail.
        int numToPrint = Math.min(MAX_PATHS_TO_PRINT, data.getNumberOfElements());
        for (int s = 0; s < numToPrint; s++) {
            System.out.println(hmm.getViterbiPathFor(data.getElementAt(s)).getFirstElement());
        }

        System.out.println(Arrays.toString(averageSignalPerState(hmm, data, numStates)));
    }
}
