package defpackage;

import de.jstacs.data.DataSet;
import de.jstacs.data.DiscreteSequenceEnumerator;
import de.jstacs.data.alphabets.DNAAlphabetContainer;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.sequenceScores.statisticalModels.StatisticalModel;
import de.jstacs.sequenceScores.statisticalModels.differentiable.DifferentiableStatisticalModel;
import de.jstacs.sequenceScores.statisticalModels.differentiable.DifferentiableStatisticalModelFactory;
import de.jstacs.sequenceScores.statisticalModels.differentiable.directedGraphicalModels.BayesianNetworkDiffSM;
import de.jstacs.utils.ComparableElement;
import de.jstacs.utils.Normalisation;
import de.jstacs.utils.Pair;
import de.jstacs.utils.ToolBox;
import java.util.Arrays;

/* loaded from: input_file:PWMTest2.class */
/**
 * Simulation experiment on position weight matrices (PWMs) of length {@value #MOTIF_LENGTH}
 * over the DNA alphabet.
 *
 * <p>For every combination of equivalent sample size (passed to {@code createPWM} for the
 * randomly initialised reference PWM) and probability-mass threshold, the experiment is
 * repeated {@value #REPETITIONS} times. Each repetition:
 * <ol>
 *   <li>draws a random reference PWM;</li>
 *   <li>keeps only the most probable sequences of that PWM whose cumulative probability reaches
 *       the threshold (see {@link #getData});</li>
 *   <li>estimates one PWM directly from those weighted sequences, and another one with
 *       {@link #train}, which additionally iteratively re-weights all remaining sequences;</li>
 *   <li>records two averaged per-position KL divergences (see {@link #getKL}).</li>
 * </ol>
 * Mean and standard deviation of both divergences are printed per parameter combination.
 */
public class PWMTest2 {

    /** Length of the simulated motifs. */
    private static final int MOTIF_LENGTH = 8;

    /** Number of independent repetitions per parameter combination. */
    private static final int REPETITIONS = 30;

    /** Pseudocount added to every PWM entry before computing a KL divergence. */
    private static final double KL_PSEUDOCOUNT = 1.0E-6d;

    /** Stop criterion of {@link #train}: maximal change of the summed tail probability. */
    private static final double TRAIN_TOLERANCE = 1.0E-6d;

    /**
     * Runs the simulation and prints progress followed by one tab-separated summary line per
     * parameter combination: ESS, mass, mean and sd of the direct-estimate KL, mean and sd of
     * the trained-model KL.
     *
     * @param strArr ignored
     * @throws Exception propagated from the Jstacs model/data-set operations
     */
    public static void main(String[] strArr) throws Exception {
        double[] essValues = {1.0d, 4.0d, 8.0d, 16.0d};
        double[] massValues = {0.1d, 0.2d, 0.3d};
        // KL between the direct estimate from the reference model's top set and the direct
        // estimate from the trained model's top set.
        double[][][] directKL = new double[essValues.length][massValues.length][REPETITIONS];
        // KL between the reference PWM and the PWM obtained by train(...).
        double[][][] trainedKL = new double[essValues.length][massValues.length][REPETITIONS];
        // Number of distinct DNA sequences of length MOTIF_LENGTH.
        int numberOfSequences = (int) Math.pow(4.0d, MOTIF_LENGTH);

        System.out.println("useWeights: true");
        for (int e = 0; e < essValues.length; e++) {
            for (int m = 0; m < massValues.length; m++) {
                double mass = massValues[m];
                System.out.println(String.valueOf(essValues[e]) + " " + mass);
                for (int rep = 0; rep < REPETITIONS; rep++) {
                    System.out.println(rep);

                    BayesianNetworkDiffSM reference = DifferentiableStatisticalModelFactory.createPWM(DNAAlphabetContainer.SINGLETON, MOTIF_LENGTH, essValues[e]);
                    reference.initializeFunctionRandomly(false);
                    double[][] referencePwm = reference.getPWM();

                    // Generic arrays cannot be created directly; getData fills every slot.
                    @SuppressWarnings("unchecked")
                    ComparableElement<Sequence, Double>[] ranking = new ComparableElement[numberOfSequences];
                    Pair<DataSet, double[]> topData = getData(MOTIF_LENGTH, reference, ranking, mass);

                    // Direct (weighted) estimate from the reference model's top sequences only.
                    BayesianNetworkDiffSM direct = DifferentiableStatisticalModelFactory.createPWM(DNAAlphabetContainer.SINGLETON, MOTIF_LENGTH, 0.0d);
                    direct.initializeFunction(0, false, new DataSet[]{topData.getFirstElement()}, new double[][]{topData.getSecondElement()});
                    double[][] directPwm = direct.getPWM();

                    // Estimate that also iteratively re-weights the sequences outside the top set.
                    BayesianNetworkDiffSM trained = DifferentiableStatisticalModelFactory.createPWM(DNAAlphabetContainer.SINGLETON, MOTIF_LENGTH, 0.0d);
                    train(trained, topData.getFirstElement(), topData.getSecondElement(), ranking, mass, null, false);
                    double[][] trainedPwm = trained.getPWM();

                    // Re-estimate directly from the trained model's own top sequences.
                    Pair<DataSet, double[]> trainedTopData = getData(MOTIF_LENGTH, trained, ranking, mass);
                    direct.initializeFunction(0, false, new DataSet[]{trainedTopData.getFirstElement()}, new double[][]{trainedTopData.getSecondElement()});

                    directKL[e][m][rep] = getKL(directPwm, direct.getPWM());
                    trainedKL[e][m][rep] = getKL(referencePwm, trainedPwm);
                }
            }
        }

        for (int e = 0; e < essValues.length; e++) {
            for (int m = 0; m < massValues.length; m++) {
                double[] direct = directKL[e][m];
                double[] trained = trainedKL[e][m];
                System.out.println(String.valueOf(essValues[e]) + "\t" + massValues[m] + "\t" + ToolBox.mean(0, direct.length, direct) + "\t" + ToolBox.sd(0, direct.length, direct) + "\t" + ToolBox.mean(0, trained.length, trained) + "\t" + ToolBox.sd(0, trained.length, trained));
            }
        }
    }

    /**
     * Initialises {@code model} from all enumerated sequences.
     *
     * <p>The top sequences get weights that are rescaled to sum to {@code mass}. All other
     * sequences start with uniform weight {@code (1 - mass) / #rest}. The weights of the
     * non-top sequences are then repeatedly replaced by their probability under the current
     * {@code model}, and the model is re-initialised. This repeats until the summed
     * non-top probability changes by at most {@value #TRAIN_TOLERANCE}.
     *
     * @param model model to (re-)initialise in place
     * @param data top sequences; only its size is used, unless {@code recomputeTopSet} applies
     * @param weights currently unused. NOTE(review): the top weights are taken from
     *        {@code ranking} or {@code weightModel} instead; confirm this is intended.
     * @param ranking all sequences sorted ascending by probability, as produced by
     *        {@link #getData}; the last entries are the most probable
     * @param mass probability mass assigned to the top sequences
     * @param weightModel if non-null, supplies the initial top weights via its probabilities
     * @param recomputeTopSet if {@code true} and {@code weightModel} is non-null, the top set
     *        (and {@code ranking}) are recomputed from {@code weightModel} via {@link #getData}
     * @throws Exception propagated from the Jstacs model/data-set operations
     */
    private static void train(DifferentiableStatisticalModel model, DataSet data, double[] weights, ComparableElement<Sequence, Double>[] ranking, double mass, DifferentiableStatisticalModel weightModel, boolean recomputeTopSet) throws Exception {
        int top = data.getNumberOfElements();
        if (weightModel != null && recomputeTopSet) {
            top = getData(weightModel.getLength(), weightModel, ranking, mass).getFirstElement().getNumberOfElements();
        }

        int total = ranking.length;
        Sequence[] sequences = new Sequence[total];
        double[] seqWeights = new double[total];
        for (int i = 0; i < total; i++) {
            // Walk the ascending ranking backwards so index 0 is the most probable sequence.
            ComparableElement<Sequence, Double> entry = ranking[total - 1 - i];
            sequences[i] = entry.getElement();
            if (i >= top) {
                seqWeights[i] = (1.0d - mass) / (total - top);
            } else if (weightModel == null) {
                seqWeights[i] = entry.getWeight().doubleValue();
            } else {
                seqWeights[i] = Math.exp(weightModel.getLogProbFor(entry.getElement()));
            }
        }

        // Rescale the top weights so that together they carry exactly `mass`.
        double scale = mass / ToolBox.sum(0, top, seqWeights);
        for (int i = 0; i < top; i++) {
            seqWeights[i] *= scale;
        }

        DataSet all = new DataSet("", sequences);
        model.initializeFunction(0, false, new DataSet[]{all}, new double[][]{seqWeights});

        // NOTE(review): there is no iteration cap; termination relies on the tail mass
        // converging. The tail weights are the model's raw probabilities and are not
        // rescaled to sum to 1 - mass. Confirm both are intended.
        double previousTailMass;
        double tailMass = Double.NEGATIVE_INFINITY;
        do {
            previousTailMass = tailMass;
            tailMass = 0.0d;
            for (int i = top; i < total; i++) {
                seqWeights[i] = Math.exp(model.getLogProbFor(all.getElementAt(i)));
                tailMass += seqWeights[i];
            }
            model.initializeFunction(0, false, new DataSet[]{all}, new double[][]{seqWeights});
        } while (Math.abs(tailMass - previousTailMass) > TRAIN_TOLERANCE);
    }

    /**
     * Computes the KL divergence D(estimate || reference) for every position (row) and
     * returns its mean over positions.
     *
     * <p>Before comparing, {@value #KL_PSEUDOCOUNT} is added to every entry and each row is
     * renormalised to sum to one. This is done on copies, so the arguments are not modified.
     *
     * @param reference matrix whose rows act as the reference distributions
     * @param estimate matrix whose rows are compared against {@code reference}; expected to
     *        have the same shape as {@code reference}
     * @return mean per-position KL divergence
     */
    private static double getKL(double[][] reference, double[][] estimate) {
        double total = 0.0d;
        for (int pos = 0; pos < reference.length; pos++) {
            double[] p = smoothedCopy(reference[pos]);
            double[] q = smoothedCopy(estimate[pos]);
            double kl = 0.0d;
            for (int a = 0; a < p.length; a++) {
                kl += q[a] * Math.log(q[a] / p[a]);
            }
            total += kl;
        }
        return total / reference.length;
    }

    /** Returns a copy of {@code row} with the KL pseudocount added and sum-normalised. */
    private static double[] smoothedCopy(double[] row) {
        double[] copy = Arrays.copyOf(row, row.length);
        for (int k = 0; k < copy.length; k++) {
            copy[k] += KL_PSEUDOCOUNT;
        }
        Normalisation.sumNormalisation(copy);
        return copy;
    }

    /**
     * Enumerates all DNA sequences of the given length and scores each with {@code model}.
     * It returns the smallest set of most probable sequences whose cumulative probability
     * reaches {@code mass}.
     *
     * <p>If the cumulative probability never reaches {@code mass} (e.g. due to rounding when
     * {@code mass} is close to one), all sequences are returned rather than an empty set.
     *
     * @param length length of the enumerated sequences
     * @param model model whose probabilities rank the sequences
     * @param ranking output array, overwritten with every enumerated sequence paired with its
     *        probability and sorted ascending. The code relies on {@link ComparableElement}
     *        ordering by weight. It must hold exactly one slot per enumerated sequence
     *        (presumably 4^length for DNA — TODO confirm the enumerator flag semantics).
     * @param mass probability-mass threshold
     * @return the selected sequences (most probable first) together with their probabilities
     * @throws Exception propagated from the Jstacs model/data-set operations
     */
    private static Pair<DataSet, double[]> getData(int length, StatisticalModel model, ComparableElement<Sequence, Double>[] ranking, double mass) throws Exception {
        DiscreteSequenceEnumerator enumerator = new DiscreteSequenceEnumerator(DNAAlphabetContainer.SINGLETON, length, false);
        int filled = 0;
        while (enumerator.hasMoreElements()) {
            // NOTE(review): nextElement2 looks like a decompiler rename of the covariant
            // nextElement(); switch back if compiling against the original Jstacs jar.
            Sequence seq = enumerator.nextElement2();
            ranking[filled++] = new ComparableElement<>(seq, Double.valueOf(Math.exp(model.getLogProbFor(seq))));
        }
        Arrays.sort(ranking);

        // Smallest count of top sequences whose cumulative probability reaches `mass`;
        // fall back to all sequences if the threshold is never reached.
        int selected = ranking.length;
        double cumulative = 0.0d;
        for (int k = 0; k < ranking.length; k++) {
            cumulative += ranking[ranking.length - 1 - k].getWeight().doubleValue();
            if (cumulative >= mass) {
                selected = k + 1;
                break;
            }
        }

        Sequence[] sequences = new Sequence[selected];
        double[] probabilities = new double[selected];
        for (int k = 0; k < selected; k++) {
            ComparableElement<Sequence, Double> entry = ranking[ranking.length - 1 - k];
            sequences[k] = entry.getElement();
            probabilities[k] = entry.getWeight().doubleValue();
        }
        return new Pair<>(new DataSet("", sequences), probabilities);
    }
}
