/*
 * Decompiled with CFR 0.152.
 */
import de.jstacs.data.DataSet;
import de.jstacs.data.DiscreteSequenceEnumerator;
import de.jstacs.data.alphabets.DNAAlphabetContainer;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.sequenceScores.statisticalModels.StatisticalModel;
import de.jstacs.sequenceScores.statisticalModels.differentiable.DifferentiableStatisticalModel;
import de.jstacs.sequenceScores.statisticalModels.differentiable.DifferentiableStatisticalModelFactory;
import de.jstacs.sequenceScores.statisticalModels.differentiable.directedGraphicalModels.BayesianNetworkDiffSM;
import de.jstacs.utils.ComparableElement;
import de.jstacs.utils.Normalisation;
import de.jstacs.utils.Pair;
import de.jstacs.utils.ToolBox;
import java.util.Arrays;

/**
 * Simulation experiment for training position weight matrices (PWMs) from a truncated,
 * probability-weighted set of top-scoring DNA sequences.
 *
 * <p>For every equivalent sample size in {@code ess} and every probability-mass fraction in
 * {@code perc}, the experiment repeatedly:
 * <ul>
 *   <li>draws a random PWM;</li>
 *   <li>extracts the most probable sequences covering that mass;</li>
 *   <li>trains a new PWM from them;</li>
 *   <li>records two divergences.</li>
 * </ul>
 * The mean and standard deviation of both divergences are printed at the end.
 */
public class PWMTest2 {

    /** Pseudocount added to every PWM entry before computing divergences, avoiding log(0). */
    private static final double PSEUDOCOUNT = 1.0E-6;

    /** Stop criterion for {@code train}: change of the total tail weight between iterations. */
    private static final double CONVERGENCE_TOLERANCE = 1.0E-6;

    /**
     * Runs the experiment.
     *
     * <p>Prints progress lines, then one tab-separated summary line per (ess, perc) combination:
     * {@code ess, perc, mean(kl), sd(kl), mean(kl2), sd(kl2)}.
     * <ul>
     *   <li>{@code kl} is {@code getKL} between the PWM estimated from the true model's top set and
     *       the PWM estimated from the learned model's top set.</li>
     *   <li>{@code kl2} is {@code getKL} between the original random PWM and the learned PWM.</li>
     * </ul>
     */
    public static void main(String[] args) throws Exception {
        int length = 8;
        double[] ess = new double[]{1.0, 4.0, 8.0, 16.0};
        double[] perc = new double[]{0.1, 0.2, 0.3};
        int reps = 30;
        double[][][] kls = new double[ess.length][perc.length][reps];
        double[][][] kls2 = new double[ess.length][perc.length][reps];
        boolean useWeights = true;
        boolean orderByQ = false;
        System.out.println("useWeights: " + useWeights);

        // One slot per DNA sequence of the given length. getData overwrites every slot on each
        // call, so a single buffer can be reused for all repetitions.
        @SuppressWarnings("unchecked")
        ComparableElement<Sequence, Double>[] mers = new ComparableElement[(int) Math.pow(4.0, length)];

        for (int e = 0; e < ess.length; e++) {
            for (int p = 0; p < perc.length; p++) {
                System.out.println(String.valueOf(ess[e]) + " " + perc[p]);
                for (int r = 0; r < reps; r++) {
                    System.out.println(r);

                    // "True" model: a random PWM with equivalent sample size ess[e].
                    BayesianNetworkDiffSM pwm = DifferentiableStatisticalModelFactory.createPWM(DNAAlphabetContainer.SINGLETON, length, ess[e]);
                    pwm.initializeFunctionRandomly(false);
                    double[][] opwm = pwm.getPWM();
                    Pair<DataSet, double[]> data = getData(length, pwm, mers, perc[p]);

                    // Reference PWM (ess 0), estimated from the true model's weighted top set.
                    BayesianNetworkDiffSM pwm2 = DifferentiableStatisticalModelFactory.createPWM(DNAAlphabetContainer.SINGLETON, length, 0.0);
                    pwm2.initializeFunction(0, false, new DataSet[]{data.getFirstElement()}, new double[][]{data.getSecondElement()});
                    double[][] tpwm = pwm2.getPWM();

                    // Learned PWM. Note: train reads the ordering of mers left by the getData call above.
                    BayesianNetworkDiffSM pwm3 = DifferentiableStatisticalModelFactory.createPWM(DNAAlphabetContainer.SINGLETON, length, 0.0);
                    train(pwm3, data.getFirstElement(), data.getSecondElement(), mers, perc[p], useWeights ? null : pwm2, orderByQ);
                    double[][] lpwm = pwm3.getPWM();

                    // Re-estimate a PWM (reusing pwm2) from the learned model's own top set.
                    Pair<DataSet, double[]> data2 = getData(length, pwm3, mers, perc[p]);
                    pwm2.initializeFunction(0, false, new DataSet[]{data2.getFirstElement()}, new double[][]{data2.getSecondElement()});
                    double[][] test = pwm2.getPWM();

                    kls[e][p][r] = getKL(tpwm, test);
                    kls2[e][p][r] = getKL(opwm, lpwm);
                }
            }
        }

        for (int e = 0; e < ess.length; e++) {
            for (int p = 0; p < perc.length; p++) {
                double[] kl = kls[e][p];
                double[] kl2 = kls2[e][p];
                double mean = ToolBox.mean(0, kl.length, kl);
                double sd = ToolBox.sd(0, kl.length, kl);
                double mean2 = ToolBox.mean(0, kl2.length, kl2);
                double sd2 = ToolBox.sd(0, kl2.length, kl2);
                System.out.println(String.valueOf(ess[e]) + "\t" + perc[p] + "\t" + mean + "\t" + sd + "\t" + mean2 + "\t" + sd2);
            }
        }
    }

    /**
     * Trains {@code pwm3} on every sequence held in {@code mers}, taken in reverse array order
     * (the highest-weight entries come last after the sort in {@code getData}).
     *
     * <p>The first {@code top} sequences (the "head") get fixed weights, rescaled so that they sum
     * to {@code perc}. All remaining sequences (the "tail") start with the uniform weight
     * {@code (1 - perc) / (n - top)}. The tail is then repeatedly re-weighted with the current
     * model's own probabilities, retraining after each pass. This stops when the summed tail weight
     * changes by at most {@link #CONVERGENCE_TOLERANCE}. There is no iteration cap.
     *
     * @param pwm3 model to train; re-initialised in place
     * @param firstElement top set of sequences; only its size is used, as the default head size
     * @param secondElement weights of the top set; not read by this method
     * @param mers all sequences with their weights, in the order left by the last {@code getData}
     *     call; overwritten (re-sorted by {@code q}) when {@code q != null && orderByQ}
     * @param perc total weight assigned to the head
     * @param q optional model; when non-null, head weights are {@code exp(q.getLogProbFor(seq))}
     *     instead of the weights stored in {@code mers}
     * @param orderByQ when {@code q} is non-null, rank sequences by {@code q} and take the head size
     *     from {@code q}'s own top set
     * @throws IllegalArgumentException if the head would be empty
     * @throws Exception propagated from model evaluation/initialisation
     */
    private static void train(DifferentiableStatisticalModel pwm3, DataSet firstElement, double[] secondElement, ComparableElement<Sequence, Double>[] mers, double perc, DifferentiableStatisticalModel q, boolean orderByQ) throws Exception {
        int top = firstElement.getNumberOfElements();
        if (q != null && orderByQ) {
            // Side effect: getData re-sorts mers according to q.
            top = getData(q.getLength(), q, mers, perc).getFirstElement().getNumberOfElements();
        }
        if (top <= 0) {
            throw new IllegalArgumentException("head set is empty (top = " + top + ")");
        }

        int n = mers.length;
        Sequence[] all = new Sequence[n];
        double[] wall = new double[n];
        for (int i = 0; i < n; i++) {
            ComparableElement<Sequence, Double> mer = mers[n - i - 1];
            all[i] = mer.getElement();
            if (i < top) {
                wall[i] = q == null ? mer.getWeight() : Math.exp(q.getLogProbFor(mer.getElement()));
            } else {
                wall[i] = (1.0 - perc) / (double) (n - top);
            }
        }

        // Rescale the head so that its weights sum to perc.
        double fac = perc / ToolBox.sum(0, top, wall);
        for (int i = 0; i < top; i++) {
            wall[i] *= fac;
        }

        DataSet data = new DataSet("", all);
        pwm3.initializeFunction(0, false, new DataSet[]{data}, new double[][]{wall});

        // NOTE(review): the decompiled original also computed min(head weight) / max(tail weight)
        // here but never used it. If rescaling the tail was intended, it is not implemented.
        double tailMass = Double.NEGATIVE_INFINITY;
        double oldTailMass;
        do {
            oldTailMass = tailMass;
            tailMass = 0.0;
            for (int i = top; i < all.length; i++) {
                wall[i] = Math.exp(pwm3.getLogProbFor(data.getElementAt(i)));
                tailMass += wall[i];
            }
            pwm3.initializeFunction(0, false, new DataSet[]{data}, new double[][]{wall});
        } while (Math.abs(tailMass - oldTailMass) > CONVERGENCE_TOLERANCE);
    }

    /**
     * Returns the average over rows of {@code sum_j test[i][j] * ln(test[i][j] / tpwm[i][j])},
     * i.e. the mean KL divergence KL(test || tpwm) in nats.
     *
     * <p>Before comparison, {@link #PSEUDOCOUNT} is added to every entry and each row is
     * sum-normalised. This is done on copies; the argument arrays are not modified.
     *
     * @param tpwm reference matrix (second argument of the divergence)
     * @param test compared matrix (first argument of the divergence); must have at least as many
     *     rows as {@code tpwm}
     * @return mean divergence per row ({@code NaN} if {@code tpwm} has no rows)
     */
    private static double getKL(double[][] tpwm, double[][] test) {
        double mkl = 0.0;
        for (int i = 0; i < tpwm.length; i++) {
            double[] reference = smoothedCopy(tpwm[i]);
            double[] estimate = smoothedCopy(test[i]);
            double kl = 0.0;
            for (int j = 0; j < reference.length; j++) {
                kl += estimate[j] * Math.log(estimate[j] / reference[j]);
            }
            mkl += kl;
        }
        return mkl / (double) tpwm.length;
    }

    /** Returns a copy of {@code row} with {@link #PSEUDOCOUNT} added to each entry, sum-normalised. */
    private static double[] smoothedCopy(double[] row) {
        double[] copy = Arrays.copyOf(row, row.length);
        for (int j = 0; j < copy.length; j++) {
            copy[j] += PSEUDOCOUNT;
        }
        Normalisation.sumNormalisation(copy);
        return copy;
    }

    /**
     * Scores every DNA sequence of the given length and returns the most probable ones.
     *
     * <p>Each sequence is scored with {@code exp(pwm.getLogProbFor(seq))}. The pairs
     * (sequence, probability) are stored in {@code mers}, overwriting its contents, and the array
     * is sorted with {@link Arrays#sort}; the highest-weight entries are read from the end. The
     * method then returns the smallest such set whose cumulative probability reaches {@code perc}.
     * If rounding keeps the total just below {@code perc}, all sequences are returned instead of
     * an empty set.
     *
     * @param length sequence length
     * @param pwm model used to score the sequences
     * @param mers output buffer; must hold at least {@code 4^length} entries
     * @param perc target cumulative probability
     * @return the selected sequences in descending probability order, paired with their
     *     probabilities
     * @throws Exception propagated from sequence enumeration or model evaluation
     */
    private static Pair<DataSet, double[]> getData(int length, StatisticalModel pwm, ComparableElement<Sequence, Double>[] mers, double perc) throws Exception {
        DiscreteSequenceEnumerator en = new DiscreteSequenceEnumerator(DNAAlphabetContainer.SINGLETON, length, false);
        int filled = 0;
        while (en.hasMoreElements()) {
            Sequence seq = (Sequence) en.nextElement();
            mers[filled++] = new ComparableElement<Sequence, Double>(seq, Math.exp(pwm.getLogProbFor(seq)));
        }
        Arrays.sort(mers);

        int top = mers.length;
        double sum = 0.0;
        for (int k = 0; k < mers.length; k++) {
            sum += mers[mers.length - 1 - k].getWeight();
            if (sum >= perc) {
                top = k + 1;
                break;
            }
        }

        Sequence[] topS = new Sequence[top];
        double[] w = new double[top];
        for (int k = 0; k < top; k++) {
            ComparableElement<Sequence, Double> mer = mers[mers.length - 1 - k];
            topS[k] = mer.getElement();
            w[k] = mer.getWeight();
        }
        return new Pair<DataSet, double[]>(new DataSet("", topS), w);
    }
}

