package de.jstacs.sequenceScores.statisticalModels.differentiable.continuous;

import cern.colt.matrix.impl.AbstractFormatter;
import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.DataSet;
import de.jstacs.data.alphabets.ContinuousAlphabet;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.io.ArrayHandler;
import de.jstacs.io.NonParsableException;
import de.jstacs.io.XMLParser;
import de.jstacs.sequenceScores.statisticalModels.differentiable.AbstractDifferentiableStatisticalModel;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.IntList;
import de.jstacs.utils.random.RandomNumberGenerator;
import htsjdk.variant.vcf.VCFConstants;
import java.text.NumberFormat;
import java.util.Arrays;

/* loaded from: input_file:de/jstacs/sequenceScores/statisticalModels/differentiable/continuous/GaussianNetwork.class */
/**
 * Differentiable model over fixed-length sequences of continuous values in which every
 * position follows a Gaussian whose mean may depend linearly on the values at other
 * ("parent") positions.
 *
 * <p>For position {@code i} with parents {@code structure[i]} the conditional mean is
 * {@code mu[i] + sum_j bij[i][j] * (x[structure[i][j]] - mu[structure[i][j]])} and the
 * precision is {@code exp(lambda[i])}.
 *
 * <p>The flat parameter vector used by {@link #getCurrentParameterValues()},
 * {@link #setParameters(double[], int)} and the gradient indices is laid out as: all
 * {@code mu}, then all {@code lambda}, then the rows of {@code bij} concatenated;
 * {@code boff[i]} is the offset of row {@code i} of {@code bij} in that vector.
 */
public class GaussianNetwork extends AbstractDifferentiableStatisticalModel {
    /** Constant term {@code 0.5 * log(2 * pi)} of the Gaussian log density. */
    private static final double HALF_LOG_TWO_PI = 0.5d * Math.log(2.0d * Math.PI);

    /** Per-position mean parameters. */
    private double[] mu;
    /** Per-position log-precision parameters ({@code exp(lambda[i])} multiplies the squared residual). */
    private double[] lambda;
    /** Regression coefficients; {@code bij[i][j]} belongs to the j-th parent of position {@code i}. */
    private double[][] bij;
    /** Parent positions; {@code structure[i]} lists the positions that position {@code i} depends on. */
    private int[][] structure;
    /** Offset of each row of {@code bij} inside the flat parameter vector. */
    private int[] boff;
    /** Equivalent sample size reported by {@link #getESS()}. */
    private double ess;

    /**
     * Creates a network over a single continuous alphabet.
     *
     * @param iArr parent positions per position; its length defines the model length
     * @throws CloneNotSupportedException if the structure array cannot be cloned
     */
    public GaussianNetwork(int[][] iArr) throws CloneNotSupportedException {
        this(new AlphabetContainer(new ContinuousAlphabet()), iArr);
    }

    /**
     * Creates a network for the given alphabet container and dependency structure. All
     * parameters start at zero and the equivalent sample size is left at {@code 0}.
     *
     * @param alphabetContainer the alphabets handed to the super class
     * @param iArr parent positions per position; its length defines the model length
     * @throws CloneNotSupportedException if the structure array cannot be cloned
     */
    public GaussianNetwork(AlphabetContainer alphabetContainer, int[][] iArr) throws CloneNotSupportedException {
        super(alphabetContainer, iArr.length);
        this.structure = (int[][]) ArrayHandler.clone(iArr);
        this.mu = new double[iArr.length];
        this.lambda = new double[iArr.length];
        // Ragged array: one coefficient row per position, sized by its number of parents.
        this.bij = new double[iArr.length][];
        this.boff = new int[iArr.length];
        // Coefficient rows are stored behind the mu and lambda sections of the flat vector.
        int offset = this.mu.length + this.lambda.length;
        for (int node = 0; node < iArr.length; node++) {
            this.boff[node] = offset;
            this.bij[node] = new double[iArr[node].length];
            offset += this.bij[node].length;
        }
    }

    /**
     * Creates a network as {@link #GaussianNetwork(AlphabetContainer, int[][])} does and
     * additionally stores the equivalent sample size.
     *
     * @param alphabetContainer the alphabets handed to the super class
     * @param iArr parent positions per position; its length defines the model length
     * @param d the equivalent sample size returned by {@link #getESS()}
     * @throws CloneNotSupportedException if the structure array cannot be cloned
     */
    public GaussianNetwork(AlphabetContainer alphabetContainer, int[][] iArr, double d) throws CloneNotSupportedException {
        this(alphabetContainer, iArr);
        this.ess = d;
    }

    /**
     * Restores a network from the XML representation written by {@link #toXML()}.
     *
     * @param stringBuffer the XML representation
     * @throws NonParsableException if the XML cannot be parsed
     */
    public GaussianNetwork(StringBuffer stringBuffer) throws NonParsableException {
        super(stringBuffer);
    }

    /**
     * Returns a deep copy: all parameter arrays and the structure are duplicated so the
     * copy shares no mutable array with this instance.
     *
     * <p>The method name is the decompiler's rename of {@code clone}.
     *
     * @return the independent copy
     * @throws CloneNotSupportedException if the super class cannot be cloned
     */
    @Override
    /* renamed from: clone */
    public GaussianNetwork mo106clone() throws CloneNotSupportedException {
        GaussianNetwork copy = (GaussianNetwork) super.mo106clone();
        copy.structure = (int[][]) this.structure.clone();
        for (int node = 0; node < copy.structure.length; node++) {
            copy.structure[node] = (int[]) this.structure[node].clone();
        }
        copy.mu = (double[]) this.mu.clone();
        copy.lambda = (double[]) this.lambda.clone();
        copy.bij = (double[][]) this.bij.clone();
        for (int node = 0; node < copy.bij.length; node++) {
            copy.bij[node] = (double[]) this.bij[node].clone();
        }
        copy.boff = (int[]) this.boff.clone();
        return copy;
    }

    /**
     * @param i the parameter index (ignored)
     * @return always {@code 0}
     */
    @Override
    public int getSizeOfEventSpaceForRandomVariablesOfParameter(int i) {
        return 0;
    }

    /**
     * @return always {@code 0}
     */
    @Override
    public double getLogNormalizationConstant() {
        return 0.0d;
    }

    /**
     * @param i the parameter index (ignored)
     * @return always {@link Double#NEGATIVE_INFINITY}
     */
    @Override
    public double getLogPartialNormalizationConstant(int i) throws Exception {
        return Double.NEGATIVE_INFINITY;
    }

    /**
     * @return always {@code 0}; this implementation contributes no prior term
     */
    @Override
    public double getLogPriorTerm() {
        return 0.0d;
    }

    /**
     * Does nothing, matching the constant {@link #getLogPriorTerm()}.
     *
     * @param dArr the gradient vector (left untouched)
     * @param i the start index inside {@code dArr} (ignored)
     */
    @Override
    public void addGradientOfLogPriorTerm(double[] dArr, int i) throws Exception {
    }

    /**
     * @return the equivalent sample size given at construction or read from XML
     */
    @Override
    public double getESS() {
        return this.ess;
    }

    /**
     * Initialises the parameters from data. First all parameters are initialised randomly
     * via {@link #initializeFunctionRandomly(boolean)}; then, per position, {@code mu} is
     * set to the weighted mean and {@code lambda} to the negative logarithm of the
     * weighted variance of the sequences in {@code dataSetArr[i]}. The coefficients
     * {@code bij} keep their random values.
     *
     * <p>Only the first {@code mu.length} positions of each sequence are read. If the
     * total weight is not positive the random initialisation is kept entirely, and a
     * position whose variance is not positive keeps its random-initialisation
     * {@code lambda}, so that no NaN or infinite parameter is produced.
     *
     * @param i index of the data set (and weight array) to use
     * @param z passed on to {@link #initializeFunctionRandomly(boolean)}
     * @param dataSetArr the data sets
     * @param dArr per-data-set sequence weights; {@code null} or a {@code null} entry
     *        means weight {@code 1} for every sequence
     */
    @Override
    public void initializeFunction(int i, boolean z, DataSet[] dataSetArr, double[][] dArr) throws Exception {
        initializeFunctionRandomly(z);
        DataSet dataSet = dataSetArr[i];
        double[] weights = (dArr == null) ? null : dArr[i];

        double[] weightedSum = new double[this.mu.length];
        double[] weightedSquareSum = new double[this.mu.length];
        double totalWeight = 0.0d;
        for (int n = 0; n < dataSet.getNumberOfElements(); n++) {
            double weight = (weights == null) ? 1.0d : weights[n];
            Sequence seq = dataSet.getElementAt(n);
            // Never index past the model length, even for longer sequences.
            int positions = Math.min(seq.getLength(), this.mu.length);
            for (int pos = 0; pos < positions; pos++) {
                double value = seq.continuousVal(pos);
                weightedSum[pos] += value * weight;
                weightedSquareSum[pos] += value * value * weight;
            }
            totalWeight += weight;
        }

        if (!(totalWeight > 0.0d)) {
            // No usable data: dividing by the total weight would only yield NaN.
            return;
        }
        for (int pos = 0; pos < this.mu.length; pos++) {
            double mean = weightedSum[pos] / totalWeight;
            double variance = (weightedSquareSum[pos] / totalWeight) - (mean * mean);
            this.mu[pos] = mean;
            // A non-positive (or NaN) variance has no finite log-precision; keep the default.
            if (variance > 0.0d) {
                this.lambda[pos] = -Math.log(variance);
            }
        }
    }

    /**
     * Draws every {@code mu} and every coefficient in {@code bij} from the generator's
     * {@code nextGaussian()} and sets every {@code lambda} to {@code log(0.1)}.
     *
     * @param z not used by this implementation
     */
    @Override
    public void initializeFunctionRandomly(boolean z) throws Exception {
        RandomNumberGenerator random = new RandomNumberGenerator();
        double defaultLambda = Math.log(0.1d);
        for (int node = 0; node < this.mu.length; node++) {
            this.mu[node] = random.nextGaussian();
            this.lambda[node] = defaultLambda;
            double[] coefficients = this.bij[node];
            for (int p = 0; p < coefficients.length; p++) {
                coefficients[p] = random.nextGaussian();
            }
        }
    }

    /**
     * Computes the log score of the sequence window starting at {@code i} and appends the
     * partial derivatives with respect to all parameters.
     *
     * <p>For every position, index/value pairs are appended in this order: the position's
     * {@code lambda}, its {@code mu}, the {@code mu} of each parent, and each of its
     * {@code bij} coefficients. A parameter index can therefore occur more than once.
     *
     * @param sequence the sequence to score
     * @param i the start position inside {@code sequence}
     * @param intList receives the parameter indices
     * @param doubleList receives the matching partial derivatives
     * @return the log score
     */
    @Override
    public double getLogScoreAndPartialDerivation(Sequence sequence, int i, IntList intList, DoubleList doubleList) {
        double logScore = 0.0d;
        for (int node = 0; node < this.structure.length; node++) {
            int[] parents = this.structure[node];
            double mean = conditionalMean(sequence, i, node);
            double residual = sequence.continuousVal(i + node) - mean;
            double precision = Math.exp(this.lambda[node]);
            logScore += logDensity(this.lambda[node], precision, residual);

            // d/d lambda
            intList.add(this.mu.length + node);
            doubleList.add(0.5d - (((precision / 2.0d) * residual) * residual));
            // d/d mu of this position
            intList.add(node);
            doubleList.add(precision * residual);
            // d/d mu of each parent (the parent's mean enters the residual with factor +bij)
            for (int p = 0; p < parents.length; p++) {
                intList.add(parents[p]);
                doubleList.add((-precision) * residual * this.bij[node][p]);
            }
            // d/d bij
            for (int p = 0; p < parents.length; p++) {
                intList.add(this.boff[node] + p);
                doubleList.add(precision * residual * (sequence.continuousVal(i + parents[p]) - this.mu[parents[p]]));
            }
        }
        return logScore;
    }

    /**
     * @return the length of the flat parameter vector: two entries per position plus one
     *         per parent relation; {@code 0} for a model without positions
     */
    @Override
    public int getNumberOfParameters() {
        if (this.boff.length == 0) {
            return 0;
        }
        int last = this.boff.length - 1;
        return this.boff[last] + this.bij[last].length;
    }

    /**
     * @return a new array holding all {@code mu}, then all {@code lambda}, then the rows
     *         of {@code bij}
     */
    @Override
    public double[] getCurrentParameterValues() throws Exception {
        double[] values = new double[getNumberOfParameters()];
        System.arraycopy(this.mu, 0, values, 0, this.mu.length);
        int offset = this.mu.length;
        System.arraycopy(this.lambda, 0, values, offset, this.lambda.length);
        offset += this.lambda.length;
        for (double[] coefficients : this.bij) {
            System.arraycopy(coefficients, 0, values, offset, coefficients.length);
            offset += coefficients.length;
        }
        return values;
    }

    /**
     * Reads the parameters from {@code dArr} in the layout produced by
     * {@link #getCurrentParameterValues()}.
     *
     * @param dArr the source vector
     * @param i the index in {@code dArr} at which this model's parameters start
     */
    @Override
    public void setParameters(double[] dArr, int i) {
        System.arraycopy(dArr, i, this.mu, 0, this.mu.length);
        int offset = i + this.mu.length;
        System.arraycopy(dArr, offset, this.lambda, 0, this.lambda.length);
        offset += this.lambda.length;
        for (double[] coefficients : this.bij) {
            System.arraycopy(dArr, offset, coefficients, 0, coefficients.length);
            offset += coefficients.length;
        }
    }

    /**
     * @return the fixed name {@code "GaussianNetwork"}
     */
    @Override
    public String getInstanceName() {
        return "GaussianNetwork";
    }

    /**
     * Computes the log score of the sequence window starting at {@code i} as the sum of
     * the per-position Gaussian log densities.
     *
     * @param sequence the sequence to score
     * @param i the start position inside {@code sequence}
     * @return the log score
     */
    @Override
    public double getLogScoreFor(Sequence sequence, int i) {
        double logScore = 0.0d;
        for (int node = 0; node < this.structure.length; node++) {
            double mean = conditionalMean(sequence, i, node);
            double residual = sequence.continuousVal(i + node) - mean;
            logScore += logDensity(this.lambda[node], Math.exp(this.lambda[node]), residual);
        }
        return logScore;
    }

    /**
     * @return always {@code true}
     */
    @Override
    public boolean isInitialized() {
        return true;
    }

    /**
     * Renders one {@code "<i>: dnorm(<mu>,<1/exp(lambda)>)"} line per position followed by
     * one {@code "b_<i>: [...]"} line per coefficient row.
     *
     * @param numberFormat not used by this implementation; numbers are printed with
     *        their default string conversion
     * @return the textual description
     */
    @Override
    public String toString(NumberFormat numberFormat) {
        // The separators are plain literals; the decompiler had matched them to the
        // constants VCFConstants.INFO_FIELD_ARRAY_SEPARATOR and
        // AbstractFormatter.DEFAULT_ROW_SEPARATOR of unrelated libraries.
        StringBuilder text = new StringBuilder();
        for (int node = 0; node < this.mu.length; node++) {
            text.append(node).append(": dnorm(").append(this.mu[node]).append(",")
                    .append(1.0d / Math.exp(this.lambda[node])).append(")\n");
        }
        for (int node = 0; node < this.bij.length; node++) {
            text.append("b_").append(node).append(": ").append(Arrays.toString(this.bij[node])).append("\n");
        }
        return text.toString();
    }

    /**
     * Serialises all parameter arrays, the structure and the equivalent sample size inside
     * a {@code GaussNet} tag. The alphabet container is not written.
     *
     * @return the XML representation
     */
    @Override
    public StringBuffer toXML() {
        StringBuffer xml = new StringBuffer();
        XMLParser.appendObjectWithTags(xml, this.bij, "bij");
        XMLParser.appendObjectWithTags(xml, this.boff, "boff");
        XMLParser.appendObjectWithTags(xml, this.lambda, "lambda");
        XMLParser.appendObjectWithTags(xml, this.mu, "mu");
        XMLParser.appendObjectWithTags(xml, this.structure, "structure");
        XMLParser.appendObjectWithTags(xml, Double.valueOf(this.ess), "ess");
        XMLParser.addTags(xml, "GaussNet");
        return xml;
    }

    /**
     * Restores the state written by {@link #toXML()}. The alphabet container is always
     * reset to a single continuous alphabet and the length to the number of positions.
     *
     * @param stringBuffer the XML representation
     * @throws NonParsableException if a mandatory element cannot be parsed
     */
    @Override
    protected void fromXML(StringBuffer stringBuffer) throws NonParsableException {
        StringBuffer content = XMLParser.extractForTag(stringBuffer, "GaussNet");
        this.bij = (double[][]) XMLParser.extractObjectForTags(content, "bij");
        this.boff = (int[]) XMLParser.extractObjectForTags(content, "boff");
        this.lambda = (double[]) XMLParser.extractObjectForTags(content, "lambda");
        this.mu = (double[]) XMLParser.extractObjectForTags(content, "mu");
        this.structure = (int[][]) XMLParser.extractObjectForTags(content, "structure");
        try {
            this.ess = ((Double) XMLParser.extractObjectForTags(content, "ess")).doubleValue();
        } catch (NonParsableException ignored) {
            // The "ess" element is optional: representations without it load with ess = 0.
            this.ess = 0.0d;
        }
        this.alphabets = new AlphabetContainer(new ContinuousAlphabet());
        this.length = this.mu.length;
    }

    /** Returns the conditional mean of position {@code node} given its parents' values in the window starting at {@code start}. */
    private double conditionalMean(Sequence sequence, int start, int node) {
        int[] parents = this.structure[node];
        double mean = this.mu[node];
        for (int p = 0; p < parents.length; p++) {
            mean += this.bij[node][p] * (sequence.continuousVal(start + parents[p]) - this.mu[parents[p]]);
        }
        return mean;
    }

    /** Returns the Gaussian log density for a residual, given the log-precision and its exponential. */
    private static double logDensity(double logPrecision, double precision, double residual) {
        return ((0.5d * logPrecision) - HALF_LOG_TWO_PI) - (((precision / 2.0d) * residual) * residual);
    }
}
