/*
 * Decompiled with CFR 0.152.
 */
import de.jstacs.data.DataSet;
import de.jstacs.data.DiscreteSequenceEnumerator;
import de.jstacs.data.alphabets.DNAAlphabetContainer;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.sequenceScores.statisticalModels.StatisticalModel;
import de.jstacs.sequenceScores.statisticalModels.differentiable.DifferentiableStatisticalModelFactory;
import de.jstacs.sequenceScores.statisticalModels.differentiable.directedGraphicalModels.BayesianNetworkDiffSM;
import de.jstacs.utils.ComparableElement;
import de.jstacs.utils.Normalisation;
import de.jstacs.utils.Pair;
import de.jstacs.utils.ToolBox;
import java.util.Arrays;

/**
 * Simulation that, for several prior strengths and probability-mass thresholds, searches the
 * candidate value in {@code ess2} whose refitted PWM is closest (mean KL divergence over
 * repetitions) to a reference PWM.
 *
 * <p>One tab-separated line is printed per (ess, perc) combination:
 * {@code ess, perc, minimal mean KL, standard deviation at the minimum, best ess2}.
 */
public class PWMTest {
    /** Pseudo-count added to every PWM entry before the KL divergence is computed. */
    private static final double PSEUDO_COUNT = 1.0E-6;

    /**
     * Runs the simulation and prints the summary table to standard output.
     *
     * @param args ignored
     * @throws Exception if any Jstacs model or data set operation fails
     */
    public static void main(String[] args) throws Exception {
        int length = 8;
        // Third argument of createPWM for the randomly drawn models (an equivalent sample
        // size, judging by the variable name).
        double[] ess = new double[]{1.0, 4.0, 8.0, 16.0, 64.0};
        // Candidate values for the refitted model pwm3; the best one per setting is reported.
        double[] ess2 = new double[]{0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.1, 0.125, 0.15, 0.175, 0.2, 0.225, 0.25, 0.275, 0.3};
        // Cumulative probability that the selected top k-mers must reach.
        double[] perc = new double[]{0.1, 0.2, 0.3};
        int reps = 100;
        double[][][][] kls = new double[ess.length][perc.length][ess2.length][reps];

        // Scratch buffer with one slot per k-mer; getData refills it from index 0 on every call,
        // so a single allocation can be shared by all repetitions.
        @SuppressWarnings("unchecked") // generic array creation is not possible in Java
        ComparableElement<Sequence, Double>[] mers = new ComparableElement[(int) Math.pow(4.0, length)];

        for (int e = 0; e < ess.length; e++) {
            for (int p = 0; p < perc.length; p++) {
                for (int r = 0; r < reps; r++) {
                    // Randomly initialised model and its most probable k-mers.
                    BayesianNetworkDiffSM pwm = DifferentiableStatisticalModelFactory.createPWM(DNAAlphabetContainer.SINGLETON, length, ess[e]);
                    pwm.initializeFunctionRandomly(false);
                    Pair<DataSet, double[]> data = PWMTest.getData(length, pwm, mers, perc[p]);

                    // Reference PWM: a model created with 0.0 and initialised from the weighted top k-mers.
                    BayesianNetworkDiffSM pwm2 = DifferentiableStatisticalModelFactory.createPWM(DNAAlphabetContainer.SINGLETON, length, 0.0);
                    pwm2.initializeFunction(0, false, new DataSet[]{data.getFirstElement()}, new double[][]{data.getSecondElement()});
                    double[][] tpwm = pwm2.getPWM();

                    // NOTE(review): pwm2 is re-initialised inside the loop below before it is read
                    // again, so this refit looks redundant. It is kept because initializeFunction
                    // may carry state that is not visible in this file.
                    Pair<DataSet, double[]> data2 = PWMTest.getData(length, pwm2, mers, perc[p]);
                    pwm2.initializeFunction(0, false, new DataSet[]{data2.getFirstElement()}, new double[][]{data2.getSecondElement()});

                    for (int e2 = 0; e2 < ess2.length; e2++) {
                        // Model with the candidate value, initialised from the same data as the reference.
                        BayesianNetworkDiffSM pwm3 = DifferentiableStatisticalModelFactory.createPWM(DNAAlphabetContainer.SINGLETON, length, ess2[e2]);
                        pwm3.initializeFunction(0, false, new DataSet[]{data.getFirstElement()}, new double[][]{data.getSecondElement()});

                        // Re-derive the top k-mers from pwm3 and refit pwm2 on them.
                        data2 = PWMTest.getData(length, pwm3, mers, perc[p]);
                        pwm2.initializeFunction(0, false, new DataSet[]{data2.getFirstElement()}, new double[][]{data2.getSecondElement()});
                        double[][] test = pwm2.getPWM();

                        kls[e][p][e2][r] = PWMTest.getKL(tpwm, test);
                    }
                }
            }
        }

        for (int e = 0; e < ess.length; e++) {
            for (int p = 0; p < perc.length; p++) {
                double minKl = Double.POSITIVE_INFINITY;
                double sdMin = 0.0;
                double essMin = 0.0;
                for (int e2 = 0; e2 < ess2.length; e2++) {
                    double[] samples = kls[e][p][e2];
                    double mean = ToolBox.mean(0, samples.length, samples);
                    double sd = ToolBox.sd(0, samples.length, samples);
                    if (mean < minKl) {
                        minKl = mean;
                        sdMin = sd;
                        essMin = ess2[e2];
                    }
                }
                System.out.println(String.valueOf(ess[e]) + "\t" + perc[p] + "\t" + minKl + "\t" + sdMin + "\t" + essMin);
            }
        }
    }

    /**
     * Computes the KL divergence of {@code test} from {@code tpwm}, averaged over positions.
     *
     * <p>Each row is smoothed with {@link #PSEUDO_COUNT} and renormalised on a private copy.
     * The arguments are left untouched, so the same reference matrix can be passed to many
     * calls without the smoothing accumulating from call to call.
     *
     * @param tpwm reference PWM, one row per position
     * @param test PWM to compare; rows are expected to have the same length as those of {@code tpwm}
     * @return the mean over positions of {@code sum_j test[i][j] * log(test[i][j] / tpwm[i][j])}
     */
    private static double getKL(double[][] tpwm, double[][] test) {
        double total = 0.0;
        for (int i = 0; i < tpwm.length; i++) {
            double[] reference = PWMTest.smoothedCopy(tpwm[i]);
            double[] candidate = PWMTest.smoothedCopy(test[i]);
            double kl = 0.0;
            for (int j = 0; j < reference.length; j++) {
                kl += candidate[j] * Math.log(candidate[j] / reference[j]);
            }
            total += kl;
        }
        return total / (double) tpwm.length;
    }

    /** Returns a copy of {@code row} with the pseudo-count added to every entry, passed through sumNormalisation. */
    private static double[] smoothedCopy(double[] row) {
        double[] copy = row.clone();
        for (int j = 0; j < copy.length; j++) {
            copy[j] += PSEUDO_COUNT;
        }
        Normalisation.sumNormalisation(copy);
        return copy;
    }

    /**
     * Enumerates all sequences of the given length, scores them with {@code pwm} and returns the
     * highest-weighted ones whose cumulative probability first reaches {@code perc}.
     *
     * @param length length of the enumerated sequences
     * @param pwm model providing the log-probability of each sequence
     * @param mers scratch buffer, overwritten and sorted by this call; it must have room for
     *             every enumerated sequence
     * @param perc cumulative probability threshold
     * @return the selected sequences as a data set, paired with their probabilities in descending
     *         order of the sorted buffer; both are empty if the threshold is never reached
     * @throws Exception if enumeration, scoring or data set construction fails
     */
    private static Pair<DataSet, double[]> getData(int length, StatisticalModel pwm, ComparableElement<Sequence, Double>[] mers, double perc) throws Exception {
        DiscreteSequenceEnumerator en = new DiscreteSequenceEnumerator(DNAAlphabetContainer.SINGLETON, length, false);
        for (int i = 0; en.hasMoreElements(); i++) {
            // The explicit cast keeps this compiling whether nextElement() is declared to
            // return Sequence or Object.
            Sequence seq = (Sequence) en.nextElement();
            mers[i] = new ComparableElement<Sequence, Double>(seq, Math.exp(pwm.getLogProbFor(seq)));
        }
        // Presumably orders by ascending weight; the code below treats the end of the array as the top.
        Arrays.sort(mers);

        // Walk down from the last element until the accumulated probability reaches perc.
        double sum = 0.0;
        int top = 0;
        for (int i = mers.length - 1; i >= 0; i--) {
            sum += mers[i].getWeight().doubleValue();
            if (sum >= perc) {
                top = mers.length - i;
                break;
            }
        }

        Sequence[] topS = new Sequence[top];
        double[] w = new double[top];
        for (int i = 0; i < top; i++) {
            ComparableElement<Sequence, Double> mer = mers[mers.length - i - 1];
            topS[i] = mer.getElement();
            w[i] = mer.getWeight();
        }
        return new Pair<DataSet, double[]>(new DataSet("", topS), w);
    }
}

