package de.jstacs.sequenceScores.statisticalModels.differentiable.directedGraphicalModels;

import cern.colt.matrix.impl.AbstractFormatter;
import de.jstacs.Storable;
import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.alphabets.DiscreteAlphabet;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.io.NonParsableException;
import de.jstacs.io.XMLParser;
import de.jstacs.utils.Normalisation;
import de.jstacs.utils.random.DirichletMRG;
import de.jstacs.utils.random.DirichletMRGParams;
import de.jtem.numericalMethods.calculus.specialFunctions.Gamma;
import java.text.NumberFormat;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.Random;
import org.biojava.bio.program.tagvalue.TagValueParser;
import org.biojavax.bio.seq.Position;

/* loaded from: input_file:de/jstacs/sequenceScores/statisticalModels/differentiable/directedGraphicalModels/BNDiffSMParameterTree.class */
public class BNDiffSMParameterTree implements Cloneable, Storable {
    /** Position this tree belongs to; used for alphabet look-ups and as this tree's index into the array of all trees. */
    private int pos;
    /** Context positions, one per tree depth: the element at depth {@code d} branches on {@code contextPoss[d]}. */
    private int[] contextPoss;
    /** Presumably the root element of the tree; it is not used within the visible part of this file. */
    private TreeElement root;
    /** Alphabet container queried for alphabet lengths and symbols. */
    private AlphabetContainer alphabet;
    /** Index of the first parent's tree in the array of all trees, or {@code -1} if there is none (see {@code TreeElement.getLogT}). */
    private int firstParent;
    /** Indices of the first children's trees in the array of all trees; {@code TreeElement.getLogZ} throws if this is {@code null}. */
    private int[] firstChildren;
    /** Shared random number generator; not used within the visible part of this file. */
    private static Random r = new Random();

    /* loaded from: input_file:de/jstacs/sequenceScores/statisticalModels/differentiable/directedGraphicalModels/BNDiffSMParameterTree$TreeElement.class */
    public class TreeElement implements Storable, Cloneable {
        /** Position this element branches on, or {@code -1} for a leaf created by the depth-based constructor. */
        private int contextPos;
        /** Sub-elements, one per symbol at {@link #contextPos}; {@code null} for leaves. */
        private TreeElement[] children;
        /** Parameters held by a leaf, one per symbol at the enclosing tree's position; {@code null} for inner elements. */
        private BNDiffSMParameter[] pars;
        /** Cached result of the leaf's {@code getLogZ} computation; {@code null} while not yet computed or after invalidation. */
        private Double fullNormalizer;
        /** Per-parameter cache filled by {@code getLogT}; entries are {@code null} while not yet computed or after invalidation. */
        private Double[] symT;
        /** Depth index this element was constructed with (index into the enclosing tree's {@code contextPoss}). */
        private int contNum;

        private TreeElement(int i, AlphabetContainer alphabetContainer) {
            this.contNum = i;
            if (i >= BNDiffSMParameterTree.this.contextPoss.length) {
                this.contextPos = -1;
                this.pars = new BNDiffSMParameter[(int) alphabetContainer.getAlphabetLengthAt(BNDiffSMParameterTree.this.pos)];
                this.fullNormalizer = null;
                this.symT = new Double[this.pars.length];
                return;
            }
            this.contextPos = BNDiffSMParameterTree.this.contextPoss[i];
            this.children = new TreeElement[(int) alphabetContainer.getAlphabetLengthAt(this.contextPos)];
            for (int i2 = 0; i2 < alphabetContainer.getAlphabetLengthAt(this.contextPos); i2++) {
                this.children[i2] = new TreeElement(i + 1, alphabetContainer);
            }
        }

        /* JADX INFO: Access modifiers changed from: private */
        /**
         * Appends a textual table of the conditional probabilities stored below this element.
         * <p>
         * Inner elements extend the context description {@code str} by their own context position
         * and symbol and delegate to each child. A leaf writes one tab-separated line holding the
         * normalized probability of every symbol, followed by the row separator.
         *
         * @param stringBuffer the buffer that receives the text
         * @param str description of the context accumulated so far
         * @param numberFormat formatter applied to each probability
         */
        public void appendToBuffer(StringBuffer stringBuffer, String str, NumberFormat numberFormat) {
            if (this.children == null) {
                final int count = this.pars.length;
                double[] logScores = new double[count];
                for (int k = 0; k < count; k++) {
                    logScores[k] = this.pars[k].getValue() + this.pars[k].getLogZ();
                }
                final double logNorm = Normalisation.getLogSum(logScores);
                for (int k = 0; k < count; k++) {
                    String label = "P(X_" + BNDiffSMParameterTree.this.pos + " = " + BNDiffSMParameterTree.this.alphabet.getSymbol(BNDiffSMParameterTree.this.pos, k) + " | " + str + "c)=";
                    double probability = Math.exp((this.pars[k].getValue() + this.pars[k].getLogZ()) - logNorm);
                    stringBuffer.append(label + numberFormat.format(probability));
                    if (k < count - 1) {
                        stringBuffer.append("\t");
                    }
                }
                stringBuffer.append(AbstractFormatter.DEFAULT_ROW_SEPARATOR);
                return;
            }
            for (int c = 0; c < this.children.length; c++) {
                String extended = str + "X_" + this.contextPos + " = " + BNDiffSMParameterTree.this.alphabet.getSymbol(this.contextPos, c) + ", ";
                this.children[c].appendToBuffer(stringBuffer, extended, numberFormat);
            }
        }

        /* JADX INFO: Access modifiers changed from: private */
        /**
         * Normalizes the parameters of every leaf below this element.
         * <p>
         * In a leaf each value is replaced by {@code value + logZ} minus the log-sum of these terms
         * over all parameters of the leaf. If the last parameter of the leaf is not free, the value
         * of that last parameter is subsequently subtracted from every value.
         */
        public void normalizeParameters() {
            if (this.children == null) {
                final int count = this.pars.length;
                double[] logScores = new double[count];
                for (int k = 0; k < count; k++) {
                    logScores[k] = this.pars[k].getValue() + this.pars[k].getLogZ();
                }
                final double logNorm = Normalisation.getLogSum(logScores);
                for (int k = 0; k < count; k++) {
                    BNDiffSMParameter par = this.pars[k];
                    par.setValue((par.getValue() + par.getLogZ()) - logNorm);
                }
                if (!this.pars[count - 1].isFree()) {
                    // Re-express all values relative to the fixed last parameter.
                    for (int k = 0; k < count; k++) {
                        this.pars[k].setValue(this.pars[k].getValue() - this.pars[count - 1].getValue());
                    }
                }
                return;
            }
            for (TreeElement child : this.children) {
                child.normalizeParameters();
            }
        }

        /* JADX INFO: Access modifiers changed from: private */
        /**
         * Adds the probabilities of every leaf below this element, each weighted by the leaf's
         * context probability, onto {@code dArr}.
         *
         * @param dArr accumulator, indexed by symbol, that is incremented in place
         * @throws Exception declared by the original signature; presumably forwarded from the normalisation helper
         */
        public void insertProbs(double[] dArr) throws Exception {
            if (this.children == null) {
                double[] probs = new double[this.pars.length];
                for (int k = 0; k < probs.length; k++) {
                    probs[k] = this.pars[k].getValue() + this.pars[k].getLogZ();
                }
                Normalisation.logSumNormalisation(probs);
                for (int k = 0; k < dArr.length; k++) {
                    dArr[k] += getContextProbability() * probs[k];
                }
                return;
            }
            for (TreeElement child : this.children) {
                child.insertProbs(dArr);
            }
        }

        /* JADX INFO: Access modifiers changed from: private */
        /**
         * Enumerates every context/symbol combination below this element and triggers the
         * {@code getLogT} computation of the enclosing position's tree for each of them.
         * <p>
         * Row {@code i} of {@code iArr} is overwritten with the (position, symbol) pair that is
         * currently being visited at this depth.
         *
         * @param iArr scratch table of (position, symbol) pairs, one row per depth
         * @param bNDiffSMParameterTreeArr the trees of all positions
         * @param iArr2 table passed through unchanged to {@code getLogT}
         * @param i row of {@code iArr} written by this element
         * @throws RuntimeException if raised by the nested computations
         */
        public void startBackward(int[][] iArr, BNDiffSMParameterTree[] bNDiffSMParameterTreeArr, int[][] iArr2, int i) throws RuntimeException {
            if (this.children == null) {
                iArr[i][0] = BNDiffSMParameterTree.this.pos;
                for (int p = 0; p < this.pars.length; p++) {
                    iArr[i][1] = this.pars[p].symbol;
                    bNDiffSMParameterTreeArr[BNDiffSMParameterTree.this.pos].getLogT(iArr, bNDiffSMParameterTreeArr, iArr2);
                }
                return;
            }
            iArr[i][0] = this.contextPos;
            for (int c = 0; c < this.children.length; c++) {
                iArr[i][1] = c;
                this.children[c].startBackward(iArr, bNDiffSMParameterTreeArr, iArr2, i + 1);
            }
        }

        /* JADX INFO: Access modifiers changed from: private */
        /**
         * Returns the cached or freshly computed "T" term for the parameter selected by the
         * (position, symbol) pairs in {@code iArr}.
         * <p>
         * Inner elements look up the row of {@code iArr} whose position equals their context
         * position, copy it to row {@code i} of {@code iArr2} and descend into the matching child.
         * A leaf looks up the row for the enclosing tree's position and the parameter with the
         * matching symbol, then:
         * <ul>
         * <li>returns the cached {@code symT} entry if present;</li>
         * <li>if the tree has no first parent ({@code firstParent == -1}), sets the parameter's
         * log-T to 0 and returns the parameter value (this case is not cached in {@code symT});</li>
         * <li>if the parent has fewer context positions than this tree, uses the parent's log-T plus
         * the log-Z of all of the parent's first children except this position;</li>
         * <li>otherwise sums (in log space) that same quantity over every symbol of one position of
         * the parent's context, namely the one with the smallest {@code iArr3[position][1]}.</li>
         * </ul>
         * The marginalisation loop uses an {@code int} counter. The previous {@code byte} counter
         * wrapped to a negative index for alphabets with more than 127 symbols.
         *
         * @param iArr (position, symbol) pairs describing the requested context and symbol
         * @param iArr2 scratch table receiving the path taken through this tree, one row per depth
         * @param bNDiffSMParameterTreeArr the trees of all positions
         * @param iArr3 per-position table whose column 1 is compared to select the context position
         *        to sum over (presumably a topological order - to be confirmed against the caller)
         * @param i depth of this element, i.e. the row of {@code iArr2} it writes
         * @return parameter value plus its log-T term
         * @throws RuntimeException if no matching context or symbol is found in {@code iArr}
         */
        public double getLogT(int[][] iArr, int[][] iArr2, BNDiffSMParameterTree[] bNDiffSMParameterTreeArr, int[][] iArr3, int i) throws RuntimeException {
            if (this.children != null) {
                for (int k = 0; k < iArr.length; k++) {
                    if (iArr[k][0] == this.contextPos) {
                        iArr2[i][0] = iArr[k][0];
                        iArr2[i][1] = iArr[k][1];
                        return this.children[iArr[k][1]].getLogT(iArr, iArr2, bNDiffSMParameterTreeArr, iArr3, i + 1);
                    }
                }
                throw new RuntimeException("Correct context not found for depth " + i + " at position " + BNDiffSMParameterTree.this.pos + Position.IN_RANGE);
            }
            for (int entry = 0; entry < iArr.length; entry++) {
                if (iArr[entry][0] != BNDiffSMParameterTree.this.pos) {
                    continue;
                }
                for (int parIdx = 0; parIdx < this.pars.length; parIdx++) {
                    if (this.pars[parIdx].symbol != iArr[entry][1]) {
                        continue;
                    }
                    if (this.symT[parIdx] != null) {
                        return this.symT[parIdx].doubleValue();
                    }
                    int parent = BNDiffSMParameterTree.this.firstParent;
                    if (parent == -1) {
                        // No parent: the T term is neutral and the result is just the parameter value.
                        this.pars[parIdx].setLogT(Double.valueOf(0.0d));
                        return this.pars[parIdx].getValue();
                    }
                    if (bNDiffSMParameterTreeArr[parent].contextPoss.length < BNDiffSMParameterTree.this.contextPoss.length) {
                        // The parent's context is shorter than ours, so nothing has to be summed out.
                        double logT = bNDiffSMParameterTreeArr[parent].getLogT(iArr2, bNDiffSMParameterTreeArr, iArr3);
                        int[] siblings = bNDiffSMParameterTreeArr[parent].firstChildren;
                        for (int s = 0; s < siblings.length; s++) {
                            if (siblings[s] != BNDiffSMParameterTree.this.pos) {
                                logT += bNDiffSMParameterTreeArr[siblings[s]].getLogZ(iArr2, bNDiffSMParameterTreeArr);
                            }
                        }
                        this.pars[parIdx].setLogT(Double.valueOf(logT));
                        this.symT[parIdx] = Double.valueOf(this.pars[parIdx].getValue() + logT);
                        return this.symT[parIdx].doubleValue();
                    }
                    // Pick the parent-context position with the smallest iArr3[.][1] entry.
                    int[] parentContext = bNDiffSMParameterTreeArr[parent].contextPoss;
                    int smallest = Integer.MAX_VALUE;
                    int summedPos = -1;
                    for (int c = 0; c < parentContext.length; c++) {
                        if (iArr3[parentContext[c]][1] < smallest) {
                            smallest = iArr3[parentContext[c]][1];
                            summedPos = parentContext[c];
                        }
                    }
                    iArr2[i][0] = summedPos;
                    int alphabetLengthAt = (int) BNDiffSMParameterTree.this.alphabet.getAlphabetLengthAt(summedPos);
                    double[] terms = new double[alphabetLengthAt];
                    // int counter: a byte would overflow for alphabets larger than 127 symbols.
                    for (int sym = 0; sym < alphabetLengthAt; sym++) {
                        iArr2[i][1] = sym;
                        terms[sym] = bNDiffSMParameterTreeArr[parent].getLogT(iArr2, bNDiffSMParameterTreeArr, iArr3);
                        int[] siblings = bNDiffSMParameterTreeArr[parent].firstChildren;
                        for (int s = 0; s < siblings.length; s++) {
                            if (siblings[s] != BNDiffSMParameterTree.this.pos) {
                                terms[sym] = terms[sym] + bNDiffSMParameterTreeArr[siblings[s]].getLogZ(iArr2, bNDiffSMParameterTreeArr);
                            }
                        }
                    }
                    double logSum = Normalisation.getLogSum(terms);
                    this.pars[parIdx].setLogT(Double.valueOf(logSum));
                    this.symT[parIdx] = Double.valueOf(this.pars[parIdx].getValue() + logSum);
                    return this.symT[parIdx].doubleValue();
                }
            }
            throw new RuntimeException("BNDiffSMParameter value not defined in context.");
        }

        /* JADX INFO: Access modifiers changed from: private */
        /**
         * Returns the log-normalizer of the leaf selected by the (position, symbol) pairs in
         * {@code iArr}, computing and caching it on first use.
         * <p>
         * Inner elements find the row of {@code iArr} matching their context position, record it in
         * row {@code i} of {@code iArr2} and descend. A leaf sums, for each of its parameters, the
         * log-Z of all first children of the enclosing tree, stores that sum in the parameter, and
         * returns the log-sum of {@code value + sum} over all parameters.
         *
         * @param iArr (position, symbol) pairs describing the requested context
         * @param iArr2 scratch table receiving the path taken, one row per depth
         * @param bNDiffSMParameterTreeArr the trees of all positions
         * @param i depth of this element, i.e. the row of {@code iArr2} it writes
         * @return the log-normalizer of the selected leaf
         * @throws RuntimeException if the context is not found or the first children are undefined
         */
        public double getLogZ(int[][] iArr, int[][] iArr2, BNDiffSMParameterTree[] bNDiffSMParameterTreeArr, int i) throws RuntimeException {
            if (this.children == null) {
                if (this.fullNormalizer != null) {
                    return this.fullNormalizer.doubleValue();
                }
                double[] scores = new double[this.pars.length];
                for (int p = 0; p < this.pars.length; p++) {
                    int[] childTrees = BNDiffSMParameterTree.this.firstChildren;
                    if (childTrees == null) {
                        throw new RuntimeException("First children of parameter " + this.pars[p].getIndex() + " not defined.");
                    }
                    iArr2[i][0] = this.pars[p].getPosition();
                    iArr2[i][1] = this.pars[p].symbol;
                    double childLogZ = 0.0d;
                    for (int c = 0; c < childTrees.length; c++) {
                        childLogZ += bNDiffSMParameterTreeArr[childTrees[c]].getLogZ(iArr2, bNDiffSMParameterTreeArr);
                    }
                    this.pars[p].setLogZ(Double.valueOf(childLogZ));
                    scores[p] = this.pars[p].getValue() + childLogZ;
                }
                this.fullNormalizer = Double.valueOf(Normalisation.getLogSum(scores));
                return this.fullNormalizer.doubleValue();
            }
            for (int row = 0; row < iArr.length; row++) {
                if (iArr[row][0] != this.contextPos) {
                    continue;
                }
                iArr2[i][0] = iArr[row][0];
                iArr2[i][1] = iArr[row][1];
                return this.children[iArr[row][1]].getLogZ(iArr, iArr2, bNDiffSMParameterTreeArr, i + 1);
            }
            throw new RuntimeException("Correct context could not be found at position " + BNDiffSMParameterTree.this.pos + " and depth " + i);
        }

        /* JADX INFO: Access modifiers changed from: private */
        /**
         * Discards all cached normalization values of this element and of everything below it:
         * the per-parameter caches, the {@code symT} entries and the {@code fullNormalizer}.
         */
        public void invalidateNormalizers() {
            if (this.children == null) {
                for (int p = 0; p < this.pars.length; p++) {
                    this.pars[p].invalidateNormalizers();
                    this.symT[p] = null;
                }
            } else {
                for (TreeElement child : this.children) {
                    child.invalidateNormalizers();
                }
            }
            this.fullNormalizer = null;
        }

        /* JADX INFO: Access modifiers changed from: private */
        /**
         * Turns this element into a deep copy of {@code treeElement}.
         * <p>
         * The context position is taken over, each child is replaced by a newly constructed element
         * that is then filled recursively, and each parameter is replaced by a clone of the
         * corresponding parameter of {@code treeElement}. All caches are reset.
         *
         * @param treeElement the element to copy from
         * @throws CloneNotSupportedException if a parameter cannot be cloned
         */
        public void cloneRest(TreeElement treeElement) throws CloneNotSupportedException {
            this.contextPos = treeElement.contextPos;
            if (this.children != null) {
                this.children = new TreeElement[this.children.length];
                for (int c = 0; c < this.children.length; c++) {
                    TreeElement template = treeElement.children[c];
                    this.children[c] = new TreeElement(template.contNum, BNDiffSMParameterTree.this.alphabet);
                    this.children[c].cloneRest(template);
                }
            }
            if (this.pars == null) {
                return;
            }
            this.pars = new BNDiffSMParameter[this.pars.length];
            for (int p = 0; p < this.pars.length; p++) {
                this.pars[p] = treeElement.pars[p].m117clone();
            }
            this.fullNormalizer = null;
            this.symT = new Double[this.pars.length];
        }

        /* JADX INFO: Access modifiers changed from: private */
        /**
         * Subtracts the value of the last parameter of each leaf from all values of that leaf.
         * Differences that are NaN or infinite are replaced by 0.
         */
        public void divideByUnfree() {
            if (this.pars != null) {
                final double reference = this.pars[this.pars.length - 1].getValue();
                for (BNDiffSMParameter par : this.pars) {
                    double shifted = par.getValue() - reference;
                    boolean unusable = Double.isNaN(shifted) || Double.isInfinite(shifted);
                    par.setValue(unusable ? 0.0d : shifted);
                }
                return;
            }
            for (TreeElement child : this.children) {
                child.divideByUnfree();
            }
        }

        /* JADX INFO: Access modifiers changed from: private */
        /**
         * Appends the parameters of all leaves below this element to {@code linkedList}, visiting
         * children in index order.
         *
         * @param linkedList the list that receives the parameters
         * @return the same list instance that was passed in
         */
        public LinkedList<BNDiffSMParameter> linearizeParameters(LinkedList<BNDiffSMParameter> linkedList) {
            if (this.children == null) {
                for (BNDiffSMParameter par : this.pars) {
                    linkedList.add(par);
                }
                return linkedList;
            }
            for (TreeElement child : this.children) {
                child.linearizeParameters(linkedList);
            }
            return linkedList;
        }

        public TreeElement(StringBuffer stringBuffer) throws NonParsableException {
            StringBuffer extractForTag = XMLParser.extractForTag(stringBuffer, "treeElement");
            this.contNum = ((Integer) XMLParser.extractObjectForTags(extractForTag, "contNum", Integer.TYPE)).intValue();
            this.contextPos = ((Integer) XMLParser.extractObjectForTags(extractForTag, "contextPos", Integer.TYPE)).intValue();
            this.children = (TreeElement[]) XMLParser.extractObjectAndAttributesForTags(extractForTag, "children", null, null, TreeElement[].class, BNDiffSMParameterTree.class, BNDiffSMParameterTree.this);
            this.pars = (BNDiffSMParameter[]) XMLParser.extractObjectForTags(extractForTag, "pars", BNDiffSMParameter[].class);
            if (this.pars != null) {
                this.symT = new Double[this.pars.length];
                this.fullNormalizer = null;
            }
        }

        /* JADX INFO: Access modifiers changed from: private */
        /**
         * Stores {@code bNDiffSMParameter} at index {@code i2} of every leaf reached through the
         * child indices listed in {@code iArr}.
         * <p>
         * At depth {@code i} the entries {@code iArr[i][1..]} name the children to descend into
         * (entry 0 of each row is skipped).
         *
         * @param i current depth, i.e. the row of {@code iArr} to read
         * @param i2 index of the parameter within the leaf
         * @param iArr per-depth lists of child indices
         * @param bNDiffSMParameter the parameter to store
         */
        public void setParameterFor(int i, int i2, int[][] iArr, BNDiffSMParameter bNDiffSMParameter) {
            if (this.children != null) {
                final int[] path = iArr[i];
                for (int k = 1; k < path.length; k++) {
                    this.children[path[k]].setParameterFor(i + 1, i2, iArr, bNDiffSMParameter);
                }
                return;
            }
            this.pars[i2] = bNDiffSMParameter;
        }

        /* JADX INFO: Access modifiers changed from: private */
        /**
         * Prints this element to standard output: its context position, followed by either the
         * parameters of a leaf or, for an inner element, every child preceded by a
         * {@code "child k:"} header.
         */
        public void print() {
            System.out.println(this.contextPos);
            if (this.children != null) {
                int index = 0;
                for (TreeElement child : this.children) {
                    System.out.println("child " + index + ":");
                    child.print();
                    index++;
                }
                return;
            }
            for (BNDiffSMParameter par : this.pars) {
                par.print();
            }
        }

        /* JADX INFO: Access modifiers changed from: private */
        /**
         * Sets the values of every leaf below this element from the counts of its parameters.
         * <p>
         * If the counts of a leaf sum to a positive number, each value becomes the logarithm of the
         * relative count. Otherwise every value becomes the logarithm of the uniform probability.
         */
        public void normalizePlugInParameters() {
            if (this.children != null) {
                for (TreeElement child : this.children) {
                    child.normalizePlugInParameters();
                }
                return;
            }
            double total = 0.0d;
            for (BNDiffSMParameter par : this.pars) {
                total += par.getCounts();
            }
            if (total > 0.0d) {
                for (BNDiffSMParameter par : this.pars) {
                    par.setValue(Math.log(par.getCounts() / total));
                }
            } else {
                final double uniform = -Math.log(this.pars.length);
                for (BNDiffSMParameter par : this.pars) {
                    par.setValue(uniform);
                }
            }
        }

        /* JADX INFO: Access modifiers changed from: private */
        /**
         * Returns the parameter that applies to {@code sequence} when the model starts at offset
         * {@code i}: inner elements branch on the symbol at their context position, the leaf selects
         * by the symbol at the enclosing tree's position.
         *
         * @param sequence the sequence to read symbols from
         * @param i offset added to every position
         * @return the matching parameter
         */
        public BNDiffSMParameter getParameterFor(Sequence sequence, int i) {
            if (this.children == null) {
                return this.pars[sequence.discreteVal(BNDiffSMParameterTree.this.pos + i)];
            }
            TreeElement next = this.children[sequence.discreteVal(this.contextPos + i)];
            return next.getParameterFor(sequence, i);
        }

        /**
         * Serializes this element: depth, context position, children and parameters, wrapped in a
         * {@code treeElement} tag. Cached normalizers are not written.
         *
         * @return the XML representation
         */
        @Override // de.jstacs.Storable
        public StringBuffer toXML() {
            final StringBuffer xml = new StringBuffer();
            final Integer depth = Integer.valueOf(this.contNum);
            final Integer context = Integer.valueOf(this.contextPos);
            XMLParser.appendObjectWithTags(xml, depth, "contNum");
            XMLParser.appendObjectWithTags(xml, context, "contextPos");
            XMLParser.appendObjectWithTags(xml, this.children, "children");
            XMLParser.appendObjectWithTags(xml, this.pars, "pars");
            XMLParser.addTags(xml, "treeElement");
            return xml;
        }

        /* JADX INFO: Access modifiers changed from: private */
        /**
         * Draws distributions from a Dirichlet for every leaf below this element and accumulates
         * weighted KL divergences to the reference distribution into {@code dArr[i..i2-1]}.
         * <p>
         * While descending, the context index {@code i3} and depth {@code i4} into {@code dArr2} are
         * advanced only as long as {@code dArr2} has a deeper level. Below that the same row of the
         * deepest level is used for all descendants.
         *
         * @param d factor applied to every accumulated term
         * @param dArr accumulators, incremented in place
         * @param i first index of {@code dArr} to update (inclusive)
         * @param i2 last index of {@code dArr} to update (exclusive)
         * @param dArr2 reference distributions, indexed by depth, context index and symbol
         * @param d2 factor applied to the Dirichlet hyper-parameters
         * @param i3 context index into {@code dArr2[i4]}
         * @param i4 depth index into {@code dArr2}
         */
        public void drawKLDivergences(double d, double[] dArr, int i, int i2, double[][][] dArr2, double d2, int i3, int i4) {
            if (this.children != null) {
                final boolean deeperLevel = i4 < dArr2.length - 1;
                for (int c = 0; c < this.children.length; c++) {
                    if (deeperLevel) {
                        this.children[c].drawKLDivergences(d, dArr, i, i2, dArr2, d2, i3 + (c * dArr2[i4].length), i4 + 1);
                    } else {
                        this.children[c].drawKLDivergences(d, dArr, i, i2, dArr2, d2, i3, i4);
                    }
                }
                return;
            }
            final double[] reference = dArr2[i4][i3];
            final double[] hyper = new double[this.pars.length];
            final double contextProbability = getContextProbability();
            for (int s = 0; s < hyper.length; s++) {
                hyper[s] = d2 * contextProbability * reference[s];
            }
            final DirichletMRGParams params = new DirichletMRGParams(hyper);
            final double[] drawn = new double[hyper.length];
            for (int sample = i; sample < i2; sample++) {
                DirichletMRG.DEFAULT_INSTANCE.generate(drawn, 0, drawn.length, params);
                for (int s = 0; s < drawn.length; s++) {
                    // Zero-probability symbols contribute nothing and would yield log(0).
                    if (drawn[s] > 0.0d) {
                        dArr[sample] += d * contextProbability * drawn[s] * Math.log(drawn[s] / reference[s]);
                    }
                }
            }
        }

        /* JADX INFO: Access modifiers changed from: private */
        /**
         * Sets the parameters of every leaf below this element from the weighted mixture of the
         * given distributions (see {@code getMarginal}).
         *
         * @param dArr mixture weights
         * @param dArr2 distributions, indexed by component, depth, context index and symbol
         * @param i context index accumulated so far
         * @param i2 current depth
         */
        public void setNewParameters(double[] dArr, double[][][][] dArr2, int i, int i2) {
            final int contextAlphabetLength = (int) BNDiffSMParameterTree.this.alphabet.getAlphabetLengthAt(this.contextPos);
            if (this.children != null) {
                final int stride = (int) Math.pow(contextAlphabetLength, i2);
                for (int c = 0; c < this.children.length; c++) {
                    this.children[c].setNewParameters(dArr, dArr2, i + (c * stride), i2 + 1);
                }
                return;
            }
            fill(getMarginal(dArr, dArr2, i, i2));
        }

        /* JADX INFO: Access modifiers changed from: private */
        /**
         * Returns the sum over all leaves below this element of the KL divergence between the
         * leaf's normalized distribution and the reference distribution, each weighted by the
         * leaf's context probability.
         * <p>
         * The context index and depth into {@code dArr} are advanced only while {@code dArr} has a
         * deeper level. Below that all descendants share the same reference row.
         *
         * @param dArr reference distributions, indexed by depth, context index and symbol
         * @param i context index into {@code dArr[i2]}
         * @param i2 depth index into {@code dArr}
         * @return the weighted KL divergence
         */
        public double getWeightedKLDivergence(double[][][] dArr, int i, int i2) {
            if (this.children != null) {
                final boolean deeperLevel = i2 < dArr.length - 1;
                double total = 0.0d;
                for (int c = 0; c < this.children.length; c++) {
                    if (deeperLevel) {
                        total += this.children[c].getWeightedKLDivergence(dArr, i + (c * dArr[i2].length), i2 + 1);
                    } else {
                        total += this.children[c].getWeightedKLDivergence(dArr, i, i2);
                    }
                }
                return total;
            }
            double[] logScores = new double[this.pars.length];
            for (int p = 0; p < this.pars.length; p++) {
                logScores[p] = this.pars[p].getValue() + this.pars[p].getLogZ();
            }
            final double logNorm = Normalisation.getLogSum(logScores);
            double divergence = 0.0d;
            for (int p = 0; p < this.pars.length; p++) {
                double prob = Math.exp((this.pars[p].getValue() + this.pars[p].getLogZ()) - logNorm);
                if (prob > 0.0d) {
                    divergence += prob * Math.log(prob / dArr[i2][i][p]);
                }
            }
            return divergence * getContextProbability();
        }

        /* JADX INFO: Access modifiers changed from: private */
        /**
         * Returns the sum over all leaves below this element of the KL divergence between the
         * leaf's normalized distribution and the weighted mixture of the given distributions
         * (see {@code getMarginal}), each weighted by the leaf's context probability.
         *
         * @param dArr mixture weights
         * @param dArr2 distributions, indexed by component, depth, context index and symbol
         * @param i context index accumulated so far
         * @param i2 current depth
         * @return the weighted KL divergence
         */
        public double getWeightedKLDivergence(double[] dArr, double[][][][] dArr2, int i, int i2) {
            double result = 0.0d;
            final int contextAlphabetLength = (int) BNDiffSMParameterTree.this.alphabet.getAlphabetLengthAt(this.contextPos);
            if (this.children != null) {
                final int stride = (int) Math.pow(contextAlphabetLength, i2);
                for (int c = 0; c < this.children.length; c++) {
                    result += this.children[c].getWeightedKLDivergence(dArr, dArr2, i + (c * stride), i2 + 1);
                }
                return result;
            }
            double[] logScores = new double[this.pars.length];
            for (int p = 0; p < this.pars.length; p++) {
                logScores[p] = logScores[p] + this.pars[p].getValue() + this.pars[p].getLogZ();
            }
            final double logNorm = Normalisation.getLogSum(logScores);
            final double[] marginal = getMarginal(dArr, dArr2, i, i2);
            for (int p = 0; p < this.pars.length; p++) {
                double prob = Math.exp((this.pars[p].getValue() + this.pars[p].getLogZ()) - logNorm);
                if (prob > 0.0d) {
                    result += prob * Math.log(prob / marginal[p]);
                }
            }
            result *= getContextProbability();
            return result;
        }

        /**
         * Computes the weighted sum of the component distributions for the context of this leaf.
         * <p>
         * For a component that has a level at depth {@code i2}, row {@code i} of that level is used.
         * Otherwise the component's deepest level is used, with the row obtained by reducing
         * {@code i} modulo the alphabet length raised to that depth.
         *
         * @param dArr mixture weights, one per component
         * @param dArr2 distributions, indexed by component, depth, context index and symbol
         * @param i context index of this leaf
         * @param i2 depth of this leaf
         * @return the mixture distribution, one entry per parameter
         */
        private double[] getMarginal(double[] dArr, double[][][][] dArr2, int i, int i2) {
            final int contextAlphabetLength = (int) BNDiffSMParameterTree.this.alphabet.getAlphabetLengthAt(this.contextPos);
            final double[] mixture = new double[this.pars.length];
            for (int comp = 0; comp < dArr.length; comp++) {
                int depth = i2;
                int row = i;
                if (i2 >= dArr2[comp].length) {
                    depth = dArr2[comp].length - 1;
                    row = i % ((int) Math.pow(contextAlphabetLength, depth));
                }
                for (int p = 0; p < this.pars.length; p++) {
                    mixture[p] += dArr[comp] * dArr2[comp][depth][row][p];
                }
            }
            return mixture;
        }

        /* JADX INFO: Access modifiers changed from: private */
        /**
         * Draws mixtures of Dirichlet samples for every leaf below this element and accumulates
         * weighted KL divergences to the mixture of the given distributions into {@code dArr}.
         *
         * @param dArr accumulators, one per sample, incremented in place
         * @param dArr2 mixture weights, one per component
         * @param dArr3 distributions, indexed by component, depth, context index and symbol
         * @param d factor applied to the Dirichlet hyper-parameters
         * @param i context index accumulated so far
         * @param i2 current depth
         */
        public void drawKLDivergences(double[] dArr, double[] dArr2, double[][][][] dArr3, double d, int i, int i2) {
            int length;
            int pow;
            int alphabetLengthAt = (int) BNDiffSMParameterTree.this.alphabet.getAlphabetLengthAt(this.contextPos);
            if (this.children != null) {
                // Inner element: each child gets its own context index, offset by alphabetLength^depth.
                int pow2 = (int) Math.pow(alphabetLengthAt, i2);
                for (int i3 = 0; i3 < this.children.length; i3++) {
                    this.children[i3].drawKLDivergences(dArr, dArr2, dArr3, d, i + (i3 * pow2), i2 + 1);
                }
                return;
            }
            // dArr4: hyper-parameter buffer; dArr5: weighted mixture of the component distributions.
            double[] dArr4 = new double[this.pars.length];
            double[] dArr5 = new double[this.pars.length];
            double contextProbability = getContextProbability();
            DirichletMRGParams[] dirichletMRGParamsArr = new DirichletMRGParams[dArr2.length];
            for (int i4 = 0; i4 < dArr2.length; i4++) {
                // Use the level at this depth if the component has one, otherwise its deepest level.
                if (i2 < dArr3[i4].length) {
                    length = i2;
                    pow = i;
                } else {
                    length = dArr3[i4].length - 1;
                    pow = i % ((int) Math.pow(alphabetLengthAt, length));
                }
                for (int i5 = 0; i5 < this.pars.length; i5++) {
                    int i6 = i5;
                    dArr5[i6] = dArr5[i6] + (dArr2[i4] * dArr3[i4][length][pow][i5]);
                    dArr4[i5] = d * contextProbability * dArr2[i4] * dArr3[i4][length][pow][i5];
                }
                // NOTE(review): the same dArr4 instance is passed for every component; unless
                // DirichletMRGParams copies its argument, all components share the values written
                // for the last one - confirm against the DirichletMRGParams implementation.
                dirichletMRGParamsArr[i4] = new DirichletMRGParams(dArr4);
            }
            // dArr6: latest single draw; dArr7: weighted mixture of the draws for one sample.
            double[] dArr6 = new double[this.pars.length];
            double[] dArr7 = new double[this.pars.length];
            for (int i7 = 0; i7 < dArr.length; i7++) {
                Arrays.fill(dArr7, 0.0d);
                for (int i8 = 0; i8 < dirichletMRGParamsArr.length; i8++) {
                    DirichletMRG.DEFAULT_INSTANCE.generate(dArr6, 0, dArr6.length, dirichletMRGParamsArr[i8]);
                    for (int i9 = 0; i9 < this.pars.length; i9++) {
                        int i10 = i9;
                        dArr7[i10] = dArr7[i10] + (dArr2[i8] * dArr6[i9]);
                    }
                }
                for (int i11 = 0; i11 < dArr6.length; i11++) {
                    // NOTE(review): the guard tests the last component's draw (dArr6) while the term
                    // uses the mixture (dArr7) - verify whether dArr7 was intended here.
                    if (dArr6[i11] > 0.0d) {
                        int i12 = i7;
                        dArr[i12] = dArr[i12] + (contextProbability * dArr7[i11] * Math.log(dArr7[i11] / dArr5[i11]));
                    }
                }
            }
        }

        /**
         * Returns the probability of the context represented by this element.
         * <p>
         * For an inner element this is the sum over its children. For a leaf it is the exponential
         * of the log-sum, over all parameters, of the normalized log-probability plus the
         * parameter's log-T term.
         *
         * @return the context probability
         */
        private double getContextProbability() {
            if (this.children == null) {
                final int count = this.pars.length;
                double[] logScores = new double[count];
                for (int p = 0; p < count; p++) {
                    logScores[p] = this.pars[p].getValue() + this.pars[p].getLogZ();
                }
                final double logNorm = Normalisation.getLogSum(logScores);
                double[] joint = new double[count];
                for (int p = 0; p < count; p++) {
                    BNDiffSMParameter par = this.pars[p];
                    joint[p] = ((par.getValue() + par.getLogZ()) - logNorm) + par.getLogT();
                }
                return Math.exp(Normalisation.getLogSum(joint));
            }
            double total = 0.0d;
            for (TreeElement child : this.children) {
                total += child.getContextProbability();
            }
            return total;
        }

        /**
         * Fills the leaves below this element from {@code dArr}, starting at row 0 with stride 1
         * and descending {@code i} levels before rows are assigned.
         *
         * @param dArr rows of values, one row per context
         * @param i number of levels to descend
         */
        private void findAndFill(double[][] dArr, int i) {
            final int firstRow = 0;
            final int initialStride = 1;
            this.fill(dArr, firstRow, initialStride, i);
        }

        /**
         * Descends {@code i3} levels, computing a row index per path, and then fills everything
         * below the reached element from that row of {@code dArr}.
         *
         * @param dArr rows of values
         * @param i row index accumulated so far
         * @param i2 stride by which the child index is multiplied at this level
         * @param i3 number of levels still to descend
         */
        private void fill(double[][] dArr, int i, int i2, int i3) {
            if (i3 > 0) {
                final int remaining = i3 - 1;
                for (int c = 0; c < this.children.length; c++) {
                    this.children[c].fill(dArr, i + (c * i2), i2 * dArr[0].length, remaining);
                }
                return;
            }
            fill(dArr[i]);
        }

        /**
         * Sets the values of every leaf below this element from the probabilities in {@code dArr}.
         * <p>
         * If the last parameter of a leaf is not free, all parameters except the last are set to
         * the log-ratio against the last probability. Otherwise each parameter is set to the
         * logarithm of its probability.
         *
         * @param dArr probabilities, one per parameter
         * @throws IndexOutOfBoundsException if all parameters are free and the lengths differ
         */
        private void fill(double[] dArr) {
            if (this.children != null) {
                for (TreeElement child : this.children) {
                    child.fill(dArr);
                }
                return;
            }
            final int last = this.pars.length - 1;
            if (this.pars[last].isFree()) {
                if (dArr.length != this.pars.length) {
                    throw new IndexOutOfBoundsException("Different number of values (" + dArr.length + ") than free parameters (" + this.pars.length + ").");
                }
                for (int p = 0; p < this.pars.length; p++) {
                    this.pars[p].setValue(Math.log(dArr[p]));
                }
                return;
            }
            for (int p = 0; p < last; p++) {
                this.pars[p].setValue(Math.log(dArr[p]) - Math.log(dArr[last]));
            }
        }

        /* JADX INFO: Access modifiers changed from: private */
        /**
         * Copies parameter values from {@code treeElement} into this subtree.
         * <p>
         * Matching structures are copied element by element. If this subtree is deeper than
         * {@code treeElement}, the same source leaf is copied into every descendant. If this
         * element is a leaf while {@code treeElement} is an inner element, each value is set to the
         * log-sum collected from the source subtree and the values are then normalized.
         *
         * @param treeElement the element to copy values from
         * @throws IndexOutOfBoundsException if the numbers of children or parameters differ
         */
        public void copy(TreeElement treeElement) {
            if (this.children != null) {
                if (treeElement.children != null && this.children.length != treeElement.children.length) {
                    throw new IndexOutOfBoundsException("Different number of children.");
                }
                for (int c = 0; c < this.children.length; c++) {
                    TreeElement source = treeElement.children == null ? treeElement : treeElement.children[c];
                    this.children[c].copy(source);
                }
                return;
            }
            if (treeElement.pars == null) {
                // Source is deeper than this leaf: collapse its subtree per parameter index.
                double[] collected = new double[this.pars.length];
                for (int p = 0; p < this.pars.length; p++) {
                    collected[p] = treeElement.getLogSum(p);
                    this.pars[p].setValue(collected[p]);
                }
                final double logNorm = Normalisation.getLogSum(collected);
                for (int p = 0; p < this.pars.length; p++) {
                    this.pars[p].setValue(this.pars[p].getValue() - logNorm);
                }
                return;
            }
            if (this.pars.length != treeElement.pars.length) {
                throw new IndexOutOfBoundsException("Different number of parameters.");
            }
            for (int p = 0; p < this.pars.length; p++) {
                this.pars[p].setValue(treeElement.pars[p].getValue());
            }
        }

        /**
         * Returns the log-sum for parameter index {@code i} over this subtree: the leaf value for a
         * leaf, otherwise the log-sum of the values reported by all children.
         *
         * @param i index of the parameter within each leaf
         * @return the combined value in log-space
         */
        private double getLogSum(int i) {
            if (this.children != null) {
                double[] perChild = new double[this.children.length];
                int slot = 0;
                while (slot < perChild.length) {
                    perChild[slot] = this.children[slot].getLogSum(i);
                    slot++;
                }
                return Normalisation.getLogSum(perChild);
            }
            return getLogSumForLeaf(i);
        }

        /**
         * Returns the sum of {@code getValue()} and {@code getLogT()} of the parameter at index {@code i}.
         *
         * @param i index into this leaf's parameter array
         * @return value plus log T of that parameter
         */
        private double getLogSumForLeaf(int i) {
            BNDiffSMParameter par = this.pars[i];
            return par.getValue() + par.getLogT();
        }

        /* JADX INFO: Access modifiers changed from: private */
        /**
         * Initialises the parameters of this subtree from a Dirichlet draw.
         * <p>
         * Inner nodes pass {@code d} divided by the alphabet length at their context position on to
         * every child. A leaf replaces a non-positive {@code d} by the alphabet length at the position
         * of its first parameter, builds one hyper-parameter per parameter ({@code d} divided by the
         * alphabet length at that parameter's position), stores the drawn values as counts, normalises
         * the plug-in parameters and, if the last parameter is not free, divides by the unfree one.
         *
         * @param d the value distributed over the hyper-parameters of the draw
         */
        public void initializeRandomly(double d) {
            if (this.pars != null) {
                double ess = d;
                if (ess <= 0.0d) {
                    ess = BNDiffSMParameterTree.this.alphabet.getAlphabetLengthAt(this.pars[0].getPosition());
                }
                double[] hyperParams = new double[this.pars.length];
                for (int k = 0; k < hyperParams.length; k++) {
                    hyperParams[k] = ess / BNDiffSMParameterTree.this.alphabet.getAlphabetLengthAt(this.pars[k].getPosition());
                }
                double[] drawn = DirichletMRG.DEFAULT_INSTANCE.generate(this.pars.length, new DirichletMRGParams(hyperParams));
                for (int k = 0; k < this.pars.length; k++) {
                    this.pars[k].count = drawn[k];
                }
                normalizePlugInParameters();
                if (!this.pars[this.pars.length - 1].isFree()) {
                    divideByUnfree();
                }
                return;
            }
            for (int c = 0; c < this.children.length; c++) {
                this.children[c].initializeRandomly(d / BNDiffSMParameterTree.this.alphabet.getAlphabetLengthAt(this.contextPos));
            }
        }

        /* JADX INFO: Access modifiers changed from: private */
        /**
         * Computes the gamma normalisation term of this subtree.
         * <p>
         * For a leaf this is {@code logGamma(sum of pseudo counts) - sum of logGamma(pseudo count)};
         * for an inner node it is the sum of the terms of all children.
         *
         * @return the accumulated term
         */
        public double computeGammaNorm() {
            if (this.children == null) {
                double pseudoCountSum = 0.0d;
                double negLogGammaSum = 0.0d;
                for (BNDiffSMParameter par : this.pars) {
                    double pseudoCount = par.getPseudoCount();
                    pseudoCountSum += pseudoCount;
                    negLogGammaSum -= Gamma.logOfGamma(pseudoCount);
                }
                return negLogGammaSum + Gamma.logOfGamma(pseudoCountSum);
            }
            double total = 0.0d;
            for (TreeElement child : this.children) {
                total += child.computeGammaNorm();
            }
            return total;
        }

        /* JADX INFO: Access modifiers changed from: private */
        /**
         * Returns the probability this subtree assigns to {@code sequence}.
         * <p>
         * A leaf multiplies its context probability by the exponentiated value of the parameter
         * selected by the last symbol of the sequence. An inner node follows the child selected by the
         * symbol at index {@code i} while {@code i} is before the last position; otherwise it sums the
         * results of all children.
         *
         * @param sequence the sequence to evaluate
         * @param i index of the sequence symbol that selects the child at this level
         * @return the resulting probability
         */
        public double getProbFor(Sequence sequence, int i) {
            if (this.children != null) {
                if (i >= sequence.getLength() - 1) {
                    // No symbol left to choose a branch: marginalise over all children.
                    double sum = 0.0d;
                    for (int c = 0; c < this.children.length; c++) {
                        sum += this.children[c].getProbFor(sequence, i + 1);
                    }
                    return sum;
                }
                TreeElement selected = this.children[sequence.discreteVal(i)];
                return selected.getProbFor(sequence, i + 1);
            }
            return getContextProbability() * this.pars[sequence.discreteVal(sequence.getLength() - 1)].getExpValue();
        }

        /* JADX INFO: Access modifiers changed from: private */
        /**
         * Counts the free parameters in this subtree.
         *
         * @return for a leaf the number of parameters whose {@code isFree()} is true, otherwise the sum
         *         over all children
         */
        public int getNumberOfParameters() {
            int count = 0;
            if (this.children == null) {
                for (BNDiffSMParameter par : this.pars) {
                    if (par.isFree()) {
                        count++;
                    }
                }
            } else {
                for (TreeElement child : this.children) {
                    count += child.getNumberOfParameters();
                }
            }
            return count;
        }

        /* JADX INFO: Access modifiers changed from: private */
        /**
         * Returns the number of sampling steps of this subtree: one per leaf.
         *
         * @return 1 for a leaf, otherwise the sum over all children
         */
        public int getNumberOfSamplingSteps() {
            if (this.children != null) {
                int steps = 0;
                for (TreeElement child : this.children) {
                    steps += child.getNumberOfSamplingSteps();
                }
                return steps;
            }
            return 1;
        }

        /* JADX INFO: Access modifiers changed from: private */
        /**
         * Returns the parameter indexes belonging to sampling step {@code i} of this subtree.
         * <p>
         * A leaf returns {@code pars.length - 1} consecutive indexes starting at {@code i2}. An inner
         * node walks its children in order, subtracting each child's number of sampling steps from the
         * step and adding its number of parameters to the offset until the step falls into a child.
         *
         * @param i the sampling step, relative to this subtree
         * @param i2 the index offset of the first parameter of this subtree
         * @return the indexes, or {@code null} if the step is not covered by any child
         */
        public int[] getParameterIndexesForSamplingStep(int i, int i2) {
            if (this.children != null) {
                int remainingStep = i;
                int offset = i2;
                for (TreeElement child : this.children) {
                    int childSteps = child.getNumberOfSamplingSteps();
                    if (remainingStep < childSteps) {
                        return child.getParameterIndexesForSamplingStep(remainingStep, offset);
                    }
                    remainingStep -= childSteps;
                    offset += child.getNumberOfParameters();
                }
                return null;
            }
            int[] indexes = new int[this.pars.length - 1];
            int next = i2;
            for (int k = 0; k < indexes.length; k++) {
                indexes[k] = next++;
            }
            return indexes;
        }

        /* JADX INFO: Access modifiers changed from: private */
        /**
         * Draws a symbol for the position of the enclosing tree and stores it in {@code iArr}.
         * <p>
         * An inner node descends into the child selected by {@code iArr[contextPos]}. A leaf collects
         * {@code getValue() + getLogZ()} of every parameter, normalises the array with
         * {@code Normalisation.logSumNormalisation} (presumably converting it in place to probabilities;
         * the array is afterwards compared against a uniform draw) and picks the first symbol whose
         * cumulative probability reaches the draw.
         * <p>
         * Because the normalised values may sum to slightly less than the draw due to floating-point
         * rounding, the cumulative scan can run off the end of the array. In that case the last symbol
         * with positive probability is emitted instead of silently leaving {@code iArr} untouched.
         *
         * @param iArr the sequence under construction; context symbols are read from it and the drawn
         *             symbol is written to the index of the enclosing tree's position
         */
        public void emitSymbol(int[] iArr) {
            if (this.children != null) {
                this.children[iArr[this.contextPos]].emitSymbol(iArr);
                return;
            }
            double[] probs = new double[this.pars.length];
            for (int k = 0; k < probs.length; k++) {
                probs[k] = this.pars[k].getValue() + this.pars[k].getLogZ();
            }
            Normalisation.logSumNormalisation(probs);
            double remaining = BNDiffSMParameterTree.r.nextDouble();
            int lastPositive = -1;
            for (int k = 0; k < probs.length; k++) {
                if (remaining - probs[k] <= 0.0d) {
                    iArr[BNDiffSMParameterTree.this.pos] = k;
                    return;
                }
                remaining -= probs[k];
                if (probs[k] > 0.0d) {
                    lastPositive = k;
                }
            }
            // Only reached when rounding kept the cumulative sum below the draw.
            if (lastPositive >= 0) {
                iArr[BNDiffSMParameterTree.this.pos] = lastPositive;
            }
        }

        /* JADX INFO: Access modifiers changed from: private */
        /**
         * Returns {@code b} plus the number of inner nodes on the path that always follows the first
         * child, starting at this element.
         *
         * @param b the order accumulated above this element
         * @return {@code b} for a leaf, otherwise the result of the first child called with {@code b + 1}
         */
        public byte getMaximalMarkovOrder(byte b) {
            if (this.children == null) {
                return b;
            }
            byte deeper = (byte) (b + 1);
            return this.children[0].getMaximalMarkovOrder(deeper);
        }

        /* JADX INFO: Access modifiers changed from: private */
        /**
         * Returns the largest parameter value of this leaf.
         *
         * @return the maximum of {@code getValue()} over all parameters, or negative infinity if no
         *         value exceeds it
         * @throws RuntimeException if this element has children
         */
        public double getMaximumScore() {
            if (this.children != null) {
                throw new RuntimeException("Not implemented");
            }
            double best = Double.NEGATIVE_INFINITY;
            for (BNDiffSMParameter par : this.pars) {
                double candidate = par.getValue();
                if (candidate > best) {
                    best = candidate;
                }
            }
            return best;
        }

        /* JADX INFO: Access modifiers changed from: private */
        /**
         * Appends the HTML table rows for this subtree to {@code stringBuffer}.
         * <p>
         * Inner nodes extend the context label by {@code "X_<contextPos> = <symbol>"} for each child
         * and recurse. A leaf emits one {@code <tr>} row: an optional first cell with the context label
         * (only if the enclosing tree has parents) followed by one cell per parameter containing
         * {@code exp(value + logZ - logSum)} formatted with {@code numberFormat}.
         *
         * @param stringBuffer the buffer to append to
         * @param str the context label accumulated so far
         * @param numberFormat the format used for the cell values
         */
        public void appendHtmlToBuffer(StringBuffer stringBuffer, String str, NumberFormat numberFormat) {
            if (this.children == null) {
                double[] logWeights = new double[this.pars.length];
                for (int k = 0; k < this.pars.length; k++) {
                    logWeights[k] = this.pars[k].getValue() + this.pars[k].getLogZ();
                }
                double logNorm = Normalisation.getLogSum(logWeights);

                boolean hasContext = BNDiffSMParameterTree.this.getNumberOfParents() > 0;
                stringBuffer.append(hasContext ? "<tr><td>" + str + "</td>" : "<tr>");

                for (int k = 0; k < this.pars.length; k++) {
                    double logProb = (this.pars[k].getValue() + this.pars[k].getLogZ()) - logNorm;
                    stringBuffer.append("<td>" + numberFormat.format(Math.exp(logProb)) + "</td>");
                }
                stringBuffer.append("</tr>");
                return;
            }
            for (int c = 0; c < this.children.length; c++) {
                String childContext = String.valueOf(str)
                        + (str.length() == 0 ? TagValueParser.EMPTY_LINE_EOR : ", ")
                        + "X_" + this.contextPos + " = "
                        + BNDiffSMParameterTree.this.alphabet.getSymbol(this.contextPos, c);
                this.children[c].appendHtmlToBuffer(stringBuffer, childContext, numberFormat);
            }
        }

        /**
         * Accessor constructor marked synthetic by the decompiler; it only forwards to the private
         * {@code TreeElement(int, AlphabetContainer)} constructor.
         *
         * @param bNDiffSMParameterTree not used in the body
         * @param i forwarded as first argument of the private constructor
         * @param alphabetContainer forwarded as second argument of the private constructor
         * @param treeElement not used in the body
         */
        /* synthetic */ TreeElement(BNDiffSMParameterTree bNDiffSMParameterTree, int i, AlphabetContainer alphabetContainer, TreeElement treeElement) {
            this(i, alphabetContainer);
        }
    }

    /**
     * Creates a parameter tree and builds its root element.
     *
     * @param i stored as the position of this tree
     * @param iArr stored as the context positions; the array is kept by reference
     * @param alphabetContainer stored as the alphabet and passed to the root element
     * @param i2 stored as the first parent
     * @param iArr2 stored as the first children; the array is kept by reference
     */
    public BNDiffSMParameterTree(int i, int[] iArr, AlphabetContainer alphabetContainer, int i2, int[] iArr2) {
        firstParent = i2;
        firstChildren = iArr2;
        pos = i;
        contextPoss = iArr;
        alphabet = alphabetContainer;
        // The root is created last: its constructor reads the fields assigned above.
        root = new TreeElement(this, 0, alphabetContainer, null);
    }

    public BNDiffSMParameterTree(StringBuffer stringBuffer) throws NonParsableException {
        StringBuffer extractForTag = XMLParser.extractForTag(stringBuffer, "parameterTree");
        this.pos = ((Integer) XMLParser.extractObjectForTags(extractForTag, "pos", Integer.TYPE)).intValue();
        this.contextPoss = (int[]) XMLParser.extractObjectForTags(extractForTag, "contextPoss", int[].class);
        this.root = new TreeElement(XMLParser.extractForTag(extractForTag, "root"));
        this.alphabet = null;
        this.firstParent = ((Integer) XMLParser.extractObjectForTags(extractForTag, "firstParent", Integer.TYPE)).intValue();
        this.firstChildren = (int[]) XMLParser.extractObjectForTags(extractForTag, "firstChildren", int[].class);
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Sets the alphabet used by this tree.
     *
     * @param alphabetContainer the new alphabet; stored by reference
     */
    public void setAlphabet(AlphabetContainer alphabetContainer) {
        alphabet = alphabetContainer;
    }

    /* renamed from: clone, reason: merged with bridge method [inline-methods] */
    /**
     * Creates a copy of this tree with its own context-position array, its own root built via
     * {@code cloneRoot()} and its own first-children array. The alphabet reference is shared.
     *
     * @return the copy
     * @throws CloneNotSupportedException if {@code super.clone()} or cloning the root fails
     */
    public BNDiffSMParameterTree m119clone() throws CloneNotSupportedException {
        final BNDiffSMParameterTree duplicate = (BNDiffSMParameterTree) super.clone();
        duplicate.contextPoss = contextPoss.clone();
        duplicate.cloneRoot();
        duplicate.firstChildren = firstChildren.clone();
        return duplicate;
    }

    /**
     * Replaces the root by a new element bound to this tree, created with the old root's
     * {@code contNum}, and lets it take over the remaining state through {@code cloneRest}.
     *
     * @throws CloneNotSupportedException if {@code cloneRest} fails
     */
    private void cloneRoot() throws CloneNotSupportedException {
        final TreeElement previousRoot = root;
        root = new TreeElement(this, previousRoot.contNum, alphabet, null);
        root.cloneRest(previousRoot);
    }

    /**
     * Returns a textual representation: a header line naming the position followed by whatever the
     * root element appends via {@code appendToBuffer}.
     *
     * @param numberFormat the format handed to the root element
     * @return the textual representation
     */
    public String toString(NumberFormat numberFormat) {
        StringBuffer text = new StringBuffer();
        text.append("Probabilities at position ").append(pos).append(":\n");
        root.appendToBuffer(text, TagValueParser.EMPTY_LINE_EOR, numberFormat);
        return text.toString();
    }

    /**
     * Hands {@code dArr} to the root element's {@code insertProbs}.
     *
     * @param dArr the values to insert
     * @throws Exception if the root element throws
     */
    public void insertProbs(double[] dArr) throws Exception {
        final TreeElement top = root;
        top.insertProbs(dArr);
    }

    /**
     * Passes a fresh, empty list to the root element's {@code linearizeParameters} and returns its
     * result.
     *
     * @return the list returned by the root element
     */
    public LinkedList<BNDiffSMParameter> linearizeParameters() {
        LinkedList<BNDiffSMParameter> collector = new LinkedList<BNDiffSMParameter>();
        return root.linearizeParameters(collector);
    }

    /**
     * Tells whether this tree is a leaf, i.e. whether its first-children array is empty.
     *
     * @return {@code true} if there are no first children
     */
    public boolean isLeaf() {
        final int numberOfChildren = firstChildren.length;
        return numberOfChildren == 0;
    }

    /**
     * Returns the number of parents, which is the length of the context-position array.
     *
     * @return the number of context positions
     */
    public int getNumberOfParents() {
        final int[] parents = contextPoss;
        return parents.length;
    }

    /**
     * Prints a header line with the position of this tree to standard output and then calls the
     * root element's {@code print()}.
     */
    public void print() {
        String header = "tree " + pos + ": ";
        System.out.println(header);
        root.print();
    }

    /**
     * Returns the parameter the root element reports for the given sequence and index.
     *
     * @param sequence the sequence passed on to the root element
     * @param i the index passed on to the root element
     * @return the result of the root element's {@code getParameterFor}
     */
    public BNDiffSMParameter getParameterFor(Sequence sequence, int i) {
        BNDiffSMParameter match = root.getParameterFor(sequence, i);
        return match;
    }

    /**
     * Forwards to the root element's {@code setParameterFor}, passing 0 as its first argument.
     *
     * @param i forwarded as second argument
     * @param iArr forwarded as third argument
     * @param bNDiffSMParameter the parameter to set; forwarded as fourth argument
     */
    public void setParameterFor(int i, int[][] iArr, BNDiffSMParameter bNDiffSMParameter) {
        final int startDepth = 0;
        root.setParameterFor(startDepth, i, iArr, bNDiffSMParameter);
    }

    /**
     * Calls {@code invalidateNormalizers()} on the root element.
     */
    public void invalidateNormalizers() {
        final TreeElement top = root;
        top.invalidateNormalizers();
    }

    /**
     * Starts the forward computation at this tree, which must not have parents.
     *
     * @param bNDiffSMParameterTreeArr the trees handed on to {@code getLogZ}
     * @return the result of {@code getLogZ} called with an empty context array
     * @throws RuntimeException if this tree has at least one parent
     */
    public double forward(BNDiffSMParameterTree[] bNDiffSMParameterTreeArr) throws RuntimeException {
        if (getNumberOfParents() <= 0) {
            int[][] emptyContext = new int[0][2];
            return getLogZ(emptyContext, bNDiffSMParameterTreeArr);
        }
        throw new RuntimeException("Forward can only be started at roots.");
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Forwards to the root element's {@code getLogZ}, supplying a fresh
     * {@code (numberOfParents + 1) x 2} int array and 0 as the last argument.
     *
     * @param iArr forwarded as first argument
     * @param bNDiffSMParameterTreeArr forwarded as third argument
     * @return the result of the root element
     * @throws RuntimeException if the root element throws
     */
    public double getLogZ(int[][] iArr, BNDiffSMParameterTree[] bNDiffSMParameterTreeArr) throws RuntimeException {
        final TreeElement top = root;
        int[][] scratch = new int[getNumberOfParents() + 1][2];
        return top.getLogZ(iArr, scratch, bNDiffSMParameterTreeArr, 0);
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Forwards to the root element's {@code getLogT}. The second argument is a fresh int array with
     * two columns whose number of rows is the number of context positions of the first parent's tree
     * plus one, or zero rows if {@code firstParent} is not greater than -1.
     *
     * @param iArr forwarded as first argument
     * @param bNDiffSMParameterTreeArr the trees, indexed by {@code firstParent}; forwarded as third argument
     * @param iArr2 forwarded as fourth argument
     * @return the result of the root element
     * @throws RuntimeException if the root element throws
     */
    public double getLogT(int[][] iArr, BNDiffSMParameterTree[] bNDiffSMParameterTreeArr, int[][] iArr2) throws RuntimeException {
        final TreeElement top = root;
        int[][] scratch;
        if (firstParent > -1) {
            scratch = new int[bNDiffSMParameterTreeArr[firstParent].contextPoss.length + 1][2];
        } else {
            scratch = new int[0][2];
        }
        return top.getLogT(iArr, scratch, bNDiffSMParameterTreeArr, iArr2, 0);
    }

    /**
     * Starts the backward computation at this tree, which must be a leaf.
     *
     * @param bNDiffSMParameterTreeArr the trees handed on to the root element
     * @param iArr handed on to the root element
     * @throws RuntimeException if this tree is not a leaf
     */
    public void backward(BNDiffSMParameterTree[] bNDiffSMParameterTreeArr, int[][] iArr) throws RuntimeException {
        if (isLeaf()) {
            final TreeElement top = root;
            int[][] scratch = new int[getNumberOfParents() + 1][2];
            top.startBackward(scratch, bNDiffSMParameterTreeArr, iArr, 0);
            return;
        }
        throw new RuntimeException("Backward can only be started at leaves.");
    }

    /**
     * Adds {@code d} to the count of the parameter selected by {@code getParameterFor(sequence, i)}.
     *
     * @param sequence the sequence used to select the parameter
     * @param i the index used to select the parameter
     * @param d the amount passed to the parameter's {@code addCount}
     */
    public void addCount(Sequence sequence, int i, double d) {
        BNDiffSMParameter target = getParameterFor(sequence, i);
        target.addCount(d);
    }

    /**
     * Calls {@code normalizePlugInParameters()} on the root element.
     */
    public void normalizePlugInParameters() {
        final TreeElement top = root;
        top.normalizePlugInParameters();
    }

    /**
     * Calls {@code normalizeParameters()} on the root element.
     */
    public void normalizeParameters() {
        final TreeElement top = root;
        top.normalizeParameters();
    }

    /**
     * Calls {@code divideByUnfree()} on the root element.
     */
    public void divideByUnfree() {
        final TreeElement top = root;
        top.divideByUnfree();
    }

    @Override // de.jstacs.Storable
    /**
     * Serialises this tree to XML: the tags {@code pos}, {@code contextPoss}, {@code root},
     * {@code firstParent} and {@code firstChildren}, wrapped in a {@code parameterTree} tag. The
     * alphabet is not written.
     *
     * @return the buffer holding the XML representation
     */
    public StringBuffer toXML() {
        final StringBuffer xml = new StringBuffer();
        Integer boxedPos = Integer.valueOf(pos);
        XMLParser.appendObjectWithTags(xml, boxedPos, "pos");
        XMLParser.appendObjectWithTags(xml, contextPoss, "contextPoss");
        XMLParser.appendObjectWithTags(xml, root, "root");
        Integer boxedFirstParent = Integer.valueOf(firstParent);
        XMLParser.appendObjectWithTags(xml, boxedFirstParent, "firstParent");
        XMLParser.appendObjectWithTags(xml, firstChildren, "firstChildren");
        XMLParser.addTags(xml, "parameterTree");
        return xml;
    }

    /**
     * Returns the stored first-parent value.
     *
     * @return the value of the {@code firstParent} field
     */
    public int getFirstParent() {
        return firstParent;
    }

    /**
     * Forwards all arguments to the root element's {@code drawKLDivergences}, appending 0 and 0 as
     * the two trailing arguments.
     *
     * @param d forwarded unchanged
     * @param dArr forwarded unchanged
     * @param i forwarded unchanged
     * @param i2 forwarded unchanged
     * @param dArr2 forwarded unchanged
     * @param d2 forwarded unchanged
     */
    public void drawKLDivergences(double d, double[] dArr, int i, int i2, double[][][] dArr2, double d2) {
        final TreeElement top = root;
        top.drawKLDivergences(d, dArr, i, i2, dArr2, d2, 0, 0);
    }

    /**
     * Returns the root element's {@code getWeightedKLDivergence} for {@code dArr}, called with 0 and
     * 0 as the two trailing arguments.
     *
     * @param dArr forwarded unchanged
     * @return the value computed by the root element
     */
    public double getKLDivergence(double[][][] dArr) {
        double divergence = root.getWeightedKLDivergence(dArr, 0, 0);
        return divergence;
    }

    /**
     * Returns the root element's {@code getWeightedKLDivergence} for the given arrays, called with 0
     * and 0 as the two trailing arguments.
     *
     * @param dArr forwarded unchanged
     * @param dArr2 forwarded unchanged
     * @return the value computed by the root element
     */
    public double getKLDivergence(double[] dArr, double[][][][] dArr2) {
        double divergence = root.getWeightedKLDivergence(dArr, dArr2, 0, 0);
        return divergence;
    }

    /**
     * Forwards all arguments to the root element's {@code drawKLDivergences}, appending 0 and 0 as
     * the two trailing arguments.
     *
     * @param dArr forwarded unchanged
     * @param dArr2 forwarded unchanged
     * @param dArr3 forwarded unchanged
     * @param d forwarded unchanged
     */
    public void drawKLDivergences(double[] dArr, double[] dArr2, double[][][][] dArr3, double d) {
        final TreeElement top = root;
        top.drawKLDivergences(dArr, dArr2, dArr3, d, 0, 0);
    }

    /**
     * Forwards to the root element's {@code setNewParameters}, appending 0 and 0 as the two trailing
     * arguments.
     *
     * @param dArr forwarded unchanged
     * @param dArr2 forwarded unchanged
     */
    public void fill(double[] dArr, double[][][][] dArr2) {
        final TreeElement top = root;
        top.setNewParameters(dArr, dArr2, 0, 0);
    }

    /**
     * Copies the parameter values of another tree into this one by letting this tree's root copy
     * from the other tree's root.
     *
     * @param bNDiffSMParameterTree the tree to copy from
     */
    public void copy(BNDiffSMParameterTree bNDiffSMParameterTree) {
        final TreeElement target = root;
        target.copy(bNDiffSMParameterTree.root);
    }

    /**
     * Initialises the parameters randomly by calling the root element's {@code initializeRandomly}.
     *
     * @param d forwarded unchanged to the root element
     */
    public void initializeRandomly(double d) {
        final TreeElement top = root;
        top.initializeRandomly(d);
    }

    /**
     * Returns the gamma normalisation term computed by the root element, boxed as a {@code Double}.
     *
     * @return the boxed result of the root element's {@code computeGammaNorm()}
     */
    public Double computeGammaNorm() {
        double norm = root.computeGammaNorm();
        return Double.valueOf(norm);
    }

    /**
     * Returns the probability the root element computes for {@code sequence}, starting at index 0.
     *
     * @param sequence the sequence to evaluate
     * @return the result of the root element's {@code getProbFor(sequence, 0)}
     */
    public double getProbFor(Sequence sequence) {
        final int startIndex = 0;
        return root.getProbFor(sequence, startIndex);
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Returns the number of sampling steps reported by the root element.
     *
     * @return the result of the root element's {@code getNumberOfSamplingSteps()}
     */
    public int getNumberOfSamplingSteps() {
        int steps = root.getNumberOfSamplingSteps();
        return steps;
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Returns the number of parameters reported by the root element.
     *
     * @return the result of the root element's {@code getNumberOfParameters()}
     */
    public int getNumberOfParameters() {
        int parameters = root.getNumberOfParameters();
        return parameters;
    }

    /**
     * Returns the parameter indexes the root element reports for a sampling step.
     *
     * @param i the sampling step, forwarded unchanged
     * @param i2 the index offset, forwarded unchanged
     * @return the result of the root element; may be {@code null}
     */
    public int[] getParameterIndexesForSamplingStep(int i, int i2) {
        int[] indexes = root.getParameterIndexesForSamplingStep(i, i2);
        return indexes;
    }

    /**
     * Lets the root element emit a symbol into {@code iArr}.
     *
     * @param iArr the array handed to the root element's {@code emitSymbol}
     */
    public void emitSymbol(int[] iArr) {
        final TreeElement top = root;
        top.emitSymbol(iArr);
    }

    /**
     * Returns the maximal Markov order reported by the root element, starting the count at 0.
     *
     * @return the result of the root element's {@code getMaximalMarkovOrder((byte) 0)}
     */
    public byte getMaximalMarkovOrder() {
        final byte initialOrder = (byte) 0;
        return root.getMaximalMarkovOrder(initialOrder);
    }

    /**
     * Returns the maximum score reported by the root element.
     *
     * @return the result of the root element's {@code getMaximumScore()}
     */
    public double getMaximumScore() {
        double best = root.getMaximumScore();
        return best;
    }

    /**
     * Renders the probabilities of this tree as an HTML fragment.
     * <p>
     * The fragment is a paragraph with a bold heading naming the position, followed by a table. The
     * header row holds an optional {@code context} column (only if this tree has parents) and one
     * column per symbol of the alphabet at this position; the body rows are appended by the root
     * element's {@code appendHtmlToBuffer}.
     * <p>
     * The heading is closed with {@code </strong>}; a second opening tag at that place would leave
     * the element unclosed and make the rest of the fragment bold.
     *
     * @param numberFormat the format handed to the root element for the table cells
     * @return the HTML fragment
     */
    public String toHtml(NumberFormat numberFormat) {
        StringBuffer html = new StringBuffer();
        html.append("<p><strong>Probabilities at position " + this.pos + ":</strong><br/>");
        html.append("<table border=\"1\"><tr>");
        if (getNumberOfParents() > 0) {
            html.append("<th>context</th>");
        }
        // The alphabet length does not change while rendering, so it is read once.
        final double numberOfSymbols = this.alphabet.getAlphabetLengthAt(this.pos);
        for (int symbol = 0; symbol < numberOfSymbols; symbol++) {
            html.append("<th>" + ((DiscreteAlphabet) this.alphabet.getAlphabetAt(this.pos)).getSymbolAt(symbol) + "</th>");
        }
        html.append("</tr>");
        this.root.appendHtmlToBuffer(html, TagValueParser.EMPTY_LINE_EOR, numberFormat);
        html.append("</table></p>");
        return html.toString();
    }
}
