package de.jstacs.sequenceScores.statisticalModels.trainable.continuous;

import cern.colt.matrix.impl.AbstractFormatter;
import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.DataSet;
import de.jstacs.data.alphabets.ContinuousAlphabet;
import de.jstacs.data.sequences.MultiDimensionalSequence;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.io.ArrayHandler;
import de.jstacs.io.NonParsableException;
import de.jstacs.io.XMLParser;
import de.jstacs.results.NumericalResultSet;
import de.jstacs.sequenceScores.statisticalModels.trainable.AbstractTrainableStatisticalModel;
import de.jtem.numericalMethods.algebra.linear.Determinant;
import de.jtem.numericalMethods.algebra.linear.Inversion;
import java.text.NumberFormat;
import java.util.Arrays;
import projects.dispom.DispomParameterSet;

/* loaded from: input_file:de/jstacs/sequenceScores/statisticalModels/trainable/continuous/MultivariateGaussian.class */
/**
 * A multivariate Gaussian density over continuous values.
 *
 * <p>Observations are vectors with {@code dimension} components. Depending on the configuration an
 * observation vector is
 * <ul>
 *   <li>an entire sequence of length {@code dimension}, where component k is the value at position k
 *       ({@code dimensionAlongPositions == true}),</li>
 *   <li>a single position of a sequence ({@code dimension == 1}), or</li>
 *   <li>a single position of a {@link MultiDimensionalSequence} whose {@code dimension} parallel
 *       sequences supply the components.</li>
 * </ul>
 *
 * <p>{@link #train(DataSet, double[])} computes closed-form estimates of the mean and of the
 * precision (inverse covariance) matrix from weighted data combined with a prior. The prior consists
 * of a prior mean, a prior matrix and two equivalent sample sizes. Either parameter can instead be
 * kept fixed at its prior value.
 */
public class MultivariateGaussian extends AbstractTrainableStatisticalModel {
    /** Number of components of one observation vector. */
    private int dimension;
    /** If {@code true}, one observation vector spans the positions of a sequence. */
    private boolean dimensionAlongPositions;
    /** Current mean vector. */
    private double[] mean;
    /** Current precision matrix, used as the inverse covariance in {@link #getLogProbFor(double[])}. */
    private double[][] precMat;
    /** Scratch buffer holding one observation vector. */
    private double[] tempValues;
    /** Scratch matrix used during training (outer products and the target of the inversion). */
    private double[][] precTemp;
    /** Set to {@code true} once {@link #train(DataSet, double[])} has completed. */
    private boolean isInitialized;
    /** Prior mean; also the mean itself when {@code meanFixed} is set. */
    private double[] priorMean;
    /** Prior matrix added to the scatter matrix in training; the precision itself when {@code precisionFixed} is set. */
    private double[][] priorPrecMat;
    /** Equivalent sample size (prior weight) of the mean prior. */
    private double essMu;
    /** Equivalent sample size of the precision prior. */
    private double essPrec;
    /** If {@code true}, training keeps the mean at {@code priorMean}. */
    private boolean meanFixed;
    /** If {@code true}, training keeps the precision at {@code priorPrecMat}. */
    private boolean precisionFixed;

    /**
     * Creates a Gaussian whose prior does not influence training.
     *
     * <p>The prior uses a zero mean and an all-zero prior matrix. The mean equivalent sample size is
     * 0 and the precision equivalent sample size equals the dimension. With this prior,
     * {@link #train(DataSet, double[])} yields the weighted sample mean and the inverse of the
     * weighted scatter matrix divided by the total weight.
     *
     * @param i the dimension
     * @param z {@code true} if one observation vector spans the positions of a sequence
     * @throws CloneNotSupportedException as declared by the delegated constructor
     */
    public MultivariateGaussian(int i, boolean z) throws CloneNotSupportedException {
        this(i, z, new double[i], getPrecMat(i, 0.0d), 0.0d, i, false, false);
    }

    /**
     * Creates a Gaussian with a diagonal prior matrix and one shared equivalent sample size.
     *
     * @param i the dimension
     * @param z {@code true} if one observation vector spans the positions of a sequence
     * @param dArr the prior mean (copied); must have length {@code i}
     * @param d the value placed on the diagonal of the prior matrix
     * @param d2 the equivalent sample size used for both the mean and the precision prior
     * @throws CloneNotSupportedException as declared by the delegated constructor
     */
    public MultivariateGaussian(int i, boolean z, double[] dArr, double d, double d2) throws CloneNotSupportedException {
        this(i, z, dArr, getPrecMat(i, d), d2, d2, false, false);
    }

    /**
     * Returns an {@code i x i} matrix with {@code d} on the diagonal and 0 elsewhere.
     *
     * @param i the number of rows and columns
     * @param d the diagonal value
     * @return the new matrix
     */
    private static double[][] getPrecMat(int i, double d) {
        double[][] dArr = new double[i][i];
        for (int i2 = 0; i2 < i; i2++) {
            dArr[i2][i2] = d;
        }
        return dArr;
    }

    /**
     * Creates a Gaussian with an explicit prior.
     *
     * @param i the dimension
     * @param z {@code true} if one observation vector spans the positions of a sequence (the model
     *     length is then {@code i}, otherwise 0)
     * @param dArr the prior mean (copied); must have length {@code i}
     * @param dArr2 the prior matrix (copied); must be {@code i x i}. During training it is added to
     *     the scatter matrix, or used directly as the precision if {@code z3} is set
     * @param d the equivalent sample size of the mean prior
     * @param d2 the equivalent sample size of the precision prior
     * @param z2 if {@code true}, the mean stays fixed at the prior mean
     * @param z3 if {@code true}, the precision stays fixed at the prior matrix
     * @throws IllegalArgumentException if the shape of the prior mean or prior matrix does not
     *     match {@code i}
     * @throws CloneNotSupportedException as declared (presumably from {@code ArrayHandler.clone})
     */
    public MultivariateGaussian(int i, boolean z, double[] dArr, double[][] dArr2, double d, double d2, boolean z2, boolean z3) throws CloneNotSupportedException {
        super(new AlphabetContainer(new ContinuousAlphabet()), z ? i : 0);
        checkPriorShape(i, dArr, dArr2);
        this.dimension = i;
        this.dimensionAlongPositions = z;
        this.mean = new double[i];
        this.tempValues = new double[i];
        this.precMat = new double[i][i];
        this.precTemp = new double[i][i];
        this.priorMean = (double[]) dArr.clone();
        this.priorPrecMat = (double[][]) ArrayHandler.clone(dArr2);
        this.essMu = d;
        this.essPrec = d2;
        this.meanFixed = z2;
        this.precisionFixed = z3;
    }

    /**
     * Verifies that the prior mean has length {@code dim} and the prior matrix is {@code dim x dim}.
     *
     * @throws IllegalArgumentException if a shape does not match
     */
    private static void checkPriorShape(int dim, double[] priorMean, double[][] priorPrecMat) {
        if (priorMean.length != dim) {
            throw new IllegalArgumentException("Prior mean has length " + priorMean.length + " but the dimension is " + dim);
        }
        if (priorPrecMat.length != dim) {
            throw new IllegalArgumentException("Prior matrix has " + priorPrecMat.length + " rows but the dimension is " + dim);
        }
        for (int row = 0; row < dim; row++) {
            if (priorPrecMat[row].length != dim) {
                throw new IllegalArgumentException("Row " + row + " of the prior matrix has length " + priorPrecMat[row].length + " but the dimension is " + dim);
            }
        }
    }

    /**
     * Restores a model from its XML representation (see {@link #toXML()}).
     *
     * @param stringBuffer the XML representation
     * @throws NonParsableException if the representation cannot be parsed
     */
    public MultivariateGaussian(StringBuffer stringBuffer) throws NonParsableException {
        super(stringBuffer);
    }

    /**
     * Returns a deep copy of this model: all array-valued fields are copied.
     *
     * @return the copy
     * @throws CloneNotSupportedException if the superclass cannot be cloned
     */
    @Override
    /* renamed from: clone */
    public MultivariateGaussian mo130clone() throws CloneNotSupportedException {
        MultivariateGaussian multivariateGaussian = (MultivariateGaussian) super.mo130clone();
        multivariateGaussian.mean = (double[]) this.mean.clone();
        multivariateGaussian.precMat = (double[][]) ArrayHandler.clone(this.precMat);
        multivariateGaussian.precTemp = (double[][]) ArrayHandler.clone(this.precTemp);
        multivariateGaussian.priorMean = (double[]) this.priorMean.clone();
        multivariateGaussian.priorPrecMat = (double[][]) ArrayHandler.clone(this.priorPrecMat);
        multivariateGaussian.tempValues = (double[]) this.tempValues.clone();
        return multivariateGaussian;
    }

    /**
     * Serializes all fields except the scratch buffer {@code tempValues}, which
     * {@link #fromXML(StringBuffer)} recreates. The result is wrapped in a
     * {@code MultivariateGaussianDiffSM} tag.
     *
     * @return the XML representation
     */
    @Override
    public StringBuffer toXML() {
        StringBuffer stringBuffer = new StringBuffer();
        XMLParser.appendObjectWithTags(stringBuffer, Integer.valueOf(this.dimension), "dimension");
        XMLParser.appendObjectWithTags(stringBuffer, Boolean.valueOf(this.dimensionAlongPositions), "dimensionAlongPositions");
        XMLParser.appendObjectWithTags(stringBuffer, Double.valueOf(this.essMu), "essMu");
        XMLParser.appendObjectWithTags(stringBuffer, Double.valueOf(this.essPrec), "essPrec");
        XMLParser.appendObjectWithTags(stringBuffer, Boolean.valueOf(this.isInitialized), "isInitialized");
        XMLParser.appendObjectWithTags(stringBuffer, this.mean, DispomParameterSet.MEAN);
        XMLParser.appendObjectWithTags(stringBuffer, Boolean.valueOf(this.meanFixed), "meanFixed");
        XMLParser.appendObjectWithTags(stringBuffer, Boolean.valueOf(this.precisionFixed), "precisionFixed");
        XMLParser.appendObjectWithTags(stringBuffer, this.precMat, "precMat");
        XMLParser.appendObjectWithTags(stringBuffer, this.precTemp, "precTemp");
        XMLParser.appendObjectWithTags(stringBuffer, this.priorMean, "priorMean");
        XMLParser.appendObjectWithTags(stringBuffer, this.priorPrecMat, "priorPrecMat");
        XMLParser.addTags(stringBuffer, "MultivariateGaussianDiffSM");
        return stringBuffer;
    }

    /**
     * Estimates the mean and the precision matrix from weighted data.
     *
     * <p>If the mean (precision) is fixed, it is reset to a copy of the prior mean (prior matrix).
     * Otherwise it is estimated by {@link #estimateMean(DataSet, double[])}
     * ({@link #estimatePrecision(DataSet, double[])}). The precision estimate uses the mean
     * computed in the first step.
     *
     * @param dataSet the training data
     * @param dArr one weight per element of {@code dataSet}, or {@code null} for weight 1 each
     * @throws Exception if a sequence does not fit the configured dimension, or if the
     *     matrix inversion fails
     */
    @Override
    public void train(DataSet dataSet, double[] dArr) throws Exception {
        if (this.meanFixed) {
            this.mean = (double[]) this.priorMean.clone();
        } else {
            estimateMean(dataSet, dArr);
        }
        if (this.precisionFixed) {
            this.precMat = (double[][]) ArrayHandler.clone(this.priorPrecMat);
        } else {
            estimatePrecision(dataSet, dArr);
        }
        this.isInitialized = true;
    }

    /** Returns the weight of element {@code n}, or 1 if no weights are given. */
    private static double weightOf(double[] weights, int n) {
        return weights == null ? 1.0d : weights[n];
    }

    /**
     * Sets {@code mean} to {@code (essMu * priorMean + sum_j w_j x_j) / (essMu + sum_j w_j)}.
     * Here x_j ranges over all observation vectors and w_j is the weight of the sequence that
     * x_j comes from.
     *
     * <p>If {@code essMu} is 0 and the data carry no weight, the division yields NaN.
     */
    private void estimateMean(DataSet dataSet, double[] weights) throws Exception {
        double totalWeight = this.essMu;
        for (int k = 0; k < this.mean.length; k++) {
            this.mean[k] = this.essMu * this.priorMean[k];
        }
        for (int n = 0; n < dataSet.getNumberOfElements(); n++) {
            Sequence seq = dataSet.getElementAt(n);
            double w = weightOf(weights, n);
            if (this.dimensionAlongPositions) {
                // The whole sequence is one observation vector.
                checkLength(seq);
                for (int pos = 0; pos < seq.getLength(); pos++) {
                    this.mean[pos] += seq.continuousVal(pos) * w;
                }
                totalWeight += w;
            } else if (this.dimension == 1) {
                // Every position is a scalar observation.
                for (int pos = 0; pos < seq.getLength(); pos++) {
                    this.mean[0] += seq.continuousVal(pos) * w;
                    totalWeight += w;
                }
            } else {
                // Every position of the parallel sequences is one observation vector.
                MultiDimensionalSequence multi = asMultiDimensional(seq);
                for (int pos = 0; pos < multi.getLength(); pos++) {
                    multi.fillContainer(this.tempValues, pos);
                    for (int k = 0; k < this.tempValues.length; k++) {
                        this.mean[k] += this.tempValues[k] * w;
                    }
                    totalWeight += w;
                }
            }
        }
        for (int k = 0; k < this.mean.length; k++) {
            this.mean[k] /= totalWeight;
        }
    }

    /**
     * Sets {@code precMat} to the inverse of
     * {@code (priorPrecMat + S + essMu * (mean - priorMean)(mean - priorMean)^T) / (essPrec - dimension + W)}.
     * Here S is the weighted scatter matrix of the observation vectors around the current
     * {@code mean}, and W is the total weight of the observations.
     *
     * <p>NOTE(review): {@link #getLogPriorTerm()} drops the {@code 0.5 * log|P|} term of the mean
     * prior when the mean is fixed. The mode of likelihood plus that prior would then use a
     * denominator one smaller than the one below. This method uses the same denominator in both
     * cases; confirm which is intended.
     */
    private void estimatePrecision(DataSet dataSet, double[] weights) throws Exception {
        for (double[] row : this.precMat) {
            Arrays.fill(row, 0.0d);
        }
        double totalWeight = 0.0d;
        for (int n = 0; n < dataSet.getNumberOfElements(); n++) {
            Sequence seq = dataSet.getElementAt(n);
            double w = weightOf(weights, n);
            if (this.dimensionAlongPositions) {
                checkLength(seq);
                for (int a = 0; a < seq.getLength(); a++) {
                    double devA = seq.continuousVal(a) - this.mean[a];
                    for (int b = 0; b < seq.getLength(); b++) {
                        this.precMat[a][b] += devA * (seq.continuousVal(b) - this.mean[b]) * w;
                    }
                }
                totalWeight += w;
            } else if (this.dimension == 1) {
                for (int pos = 0; pos < seq.getLength(); pos++) {
                    double dev = seq.continuousVal(pos) - this.mean[0];
                    this.precMat[0][0] += dev * dev * w;
                    totalWeight += w;
                }
            } else {
                MultiDimensionalSequence multi = asMultiDimensional(seq);
                for (int pos = 0; pos < multi.getLength(); pos++) {
                    multi.fillContainer(this.tempValues, pos);
                    for (int a = 0; a < this.tempValues.length; a++) {
                        double devA = this.tempValues[a] - this.mean[a];
                        for (int b = 0; b < this.tempValues.length; b++) {
                            this.precMat[a][b] += devA * (this.tempValues[b] - this.mean[b]) * w;
                        }
                    }
                    totalWeight += w;
                }
            }
        }
        // Outer product of the deviation of the estimated mean from the prior mean.
        for (int a = 0; a < this.mean.length; a++) {
            for (int b = 0; b < this.mean.length; b++) {
                this.precTemp[a][b] = (this.mean[a] - this.priorMean[a]) * (this.mean[b] - this.priorMean[b]);
            }
        }
        double denominator = (this.essPrec - this.dimension) + totalWeight;
        for (int a = 0; a < this.precMat.length; a++) {
            for (int b = 0; b < this.precMat[a].length; b++) {
                this.precMat[a][b] = ((this.priorPrecMat[a][b] + this.precMat[a][b]) + (this.essMu * this.precTemp[a][b])) / denominator;
            }
        }
        // precMat now holds a covariance estimate; invert it into precTemp and swap the buffers so
        // that precMat holds the precision and precTemp is reused as scratch.
        Inversion.compute(this.precMat, this.precTemp);
        double[][] covariance = this.precMat;
        this.precMat = this.precTemp;
        this.precTemp = covariance;
    }

    /**
     * Ensures that a sequence used as one observation vector has exactly {@code dimension} positions.
     *
     * @throws Exception if the length differs from the dimension
     */
    private void checkLength(Sequence seq) throws Exception {
        if (seq.getLength() != this.dimension) {
            throw new Exception("Sequence length " + seq.getLength() + " does not match the dimension " + this.dimension);
        }
    }

    /**
     * Casts {@code seq} to a {@link MultiDimensionalSequence} with {@code dimension} parallel sequences.
     *
     * @throws Exception if {@code seq} is not multi-dimensional or has the wrong number of sequences
     */
    private MultiDimensionalSequence asMultiDimensional(Sequence seq) throws Exception {
        if (!(seq instanceof MultiDimensionalSequence)) {
            throw new Exception("Expected a MultiDimensionalSequence for dimension " + this.dimension + " but got " + seq.getClass().getName());
        }
        MultiDimensionalSequence multi = (MultiDimensionalSequence) seq;
        if (multi.getNumberOfSequences() != this.dimension) {
            throw new Exception("Number of sequences (" + multi.getNumberOfSequences() + ") does not match the dimension " + this.dimension);
        }
        return multi;
    }

    /**
     * Returns the Gaussian log density of one observation vector under the current mean and
     * precision: {@code 0.5 * (-(x - mean)^T P (x - mean) + log|P| - dimension * log(2 pi))}.
     *
     * @param dArr the observation vector; must have length {@code dimension}
     * @return the log density
     * @throws IllegalArgumentException if {@code dArr} does not have length {@code dimension}
     */
    public double getLogProbFor(double[] dArr) {
        if (dArr.length != this.dimension) {
            throw new IllegalArgumentException("Vector has length " + dArr.length + " but the dimension is " + this.dimension);
        }
        double quadratic = 0.0d;
        for (int a = 0; a < dArr.length; a++) {
            for (int b = 0; b < dArr.length; b++) {
                quadratic -= ((dArr[a] - this.mean[a]) * this.precMat[a][b]) * (dArr[b] - this.mean[b]);
            }
        }
        return (quadratic + (Math.log(Determinant.compute(this.precMat)) - (this.dimension * Math.log(6.283185307179586d)))) * 0.5d;
    }

    /**
     * Returns the log density of the positions {@code i} to {@code i2} (inclusive) of a sequence.
     *
     * <p>If the dimension runs along positions, the window is one observation vector and must
     * have length {@code dimension}. Otherwise every position in the window is an independent
     * observation and the log densities are summed.
     *
     * @param sequence the sequence
     * @param i the first position (inclusive)
     * @param i2 the last position (inclusive)
     * @return the log density
     * @throws Exception if the window or the sequence does not fit the configured dimension
     */
    @Override
    public double getLogProbFor(Sequence sequence, int i, int i2) throws Exception {
        if (this.dimensionAlongPositions) {
            int windowLength = (i2 - i) + 1;
            if (windowLength != this.dimension) {
                throw new Exception("Window [" + i + ", " + i2 + "] has length " + windowLength + " but the dimension is " + this.dimension);
            }
            for (int pos = i; pos <= i2; pos++) {
                this.tempValues[pos - i] = sequence.continuousVal(pos);
            }
            return getLogProbFor(this.tempValues);
        }
        double logProb = 0.0d;
        if (this.dimension == 1) {
            for (int pos = i; pos <= i2; pos++) {
                this.tempValues[0] = sequence.continuousVal(pos);
                logProb += getLogProbFor(this.tempValues);
            }
        } else {
            MultiDimensionalSequence multi = asMultiDimensional(sequence);
            for (int pos = i; pos <= i2; pos++) {
                multi.fillContainer(this.tempValues, pos);
                logProb += getLogProbFor(this.tempValues);
            }
        }
        return logProb;
    }

    /**
     * Returns the log prior density of the current parameters. Only terms that depend on the
     * current mean or precision are included.
     *
     * <p>With P the precision, T the prior matrix and mu0 the prior mean, the result is:
     * <ul>
     *   <li>if the mean is not fixed: {@code -essMu/2 * (mu0 - mean)^T P (mu0 - mean) + 0.5 * log|P|};</li>
     *   <li>if the precision is not fixed: {@code (essPrec - dimension - 1)/2 * log|P| - 0.5 * tr(T P)}.</li>
     * </ul>
     *
     * @return the log prior term (0 if both parameters are fixed)
     * @throws Exception as declared by the overridden method
     */
    @Override
    public double getLogPriorTerm() throws Exception {
        double logPrior = 0.0d;
        if (!this.meanFixed) {
            for (int a = 0; a < this.mean.length; a++) {
                for (int b = 0; b < this.mean.length; b++) {
                    logPrior += (this.priorMean[a] - this.mean[a]) * this.precMat[a][b] * (this.priorMean[b] - this.mean[b]);
                }
            }
            logPrior *= (-this.essMu) / 2.0d;
        }
        double logDet = Math.log(Determinant.compute(this.precMat));
        if (!this.precisionFixed) {
            logPrior += (((this.essPrec - this.dimension) - 1.0d) / 2.0d) * logDet;
            for (int a = 0; a < this.mean.length; a++) {
                for (int b = 0; b < this.mean.length; b++) {
                    logPrior -= (0.5d * this.priorPrecMat[a][b]) * this.precMat[b][a];
                }
            }
        }
        if (!this.meanFixed) {
            logPrior += 0.5d * logDet;
        }
        return logPrior;
    }

    /**
     * Returns the instance name of this model.
     *
     * <p>NOTE(review): currently always {@code null}; callers that print or compare names may need
     * a non-null value.
     *
     * @return {@code null}
     */
    @Override
    public String getInstanceName() {
        return null;
    }

    /**
     * Numerical characteristics are not provided by this model.
     *
     * @return {@code null}
     */
    @Override
    public NumericalResultSet getNumericalCharacteristics() throws Exception {
        return null;
    }

    /**
     * @return {@code true} once {@link #train(DataSet, double[])} has completed (or if the restored
     *     XML says so)
     */
    @Override
    public boolean isInitialized() {
        return this.isInitialized;
    }

    /**
     * Restores all fields written by {@link #toXML()}. It then recreates the scratch buffer, the
     * continuous alphabet, and the model length ({@code dimension} if the dimension runs along
     * positions, otherwise 0).
     *
     * @param stringBuffer the XML representation
     * @throws NonParsableException if a tag is missing or malformed
     */
    @Override
    protected void fromXML(StringBuffer stringBuffer) throws NonParsableException {
        StringBuffer extractForTag = XMLParser.extractForTag(stringBuffer, "MultivariateGaussianDiffSM");
        this.dimension = ((Integer) XMLParser.extractObjectForTags(extractForTag, "dimension", Integer.TYPE)).intValue();
        this.dimensionAlongPositions = ((Boolean) XMLParser.extractObjectForTags(extractForTag, "dimensionAlongPositions", Boolean.TYPE)).booleanValue();
        this.essMu = ((Double) XMLParser.extractObjectForTags(extractForTag, "essMu", Double.TYPE)).doubleValue();
        this.essPrec = ((Double) XMLParser.extractObjectForTags(extractForTag, "essPrec", Double.TYPE)).doubleValue();
        this.isInitialized = ((Boolean) XMLParser.extractObjectForTags(extractForTag, "isInitialized", Boolean.TYPE)).booleanValue();
        this.mean = (double[]) XMLParser.extractObjectForTags(extractForTag, DispomParameterSet.MEAN);
        this.meanFixed = ((Boolean) XMLParser.extractObjectForTags(extractForTag, "meanFixed", Boolean.TYPE)).booleanValue();
        this.precisionFixed = ((Boolean) XMLParser.extractObjectForTags(extractForTag, "precisionFixed", Boolean.TYPE)).booleanValue();
        this.precMat = (double[][]) XMLParser.extractObjectForTags(extractForTag, "precMat");
        this.precTemp = (double[][]) XMLParser.extractObjectForTags(extractForTag, "precTemp");
        this.priorMean = (double[]) XMLParser.extractObjectForTags(extractForTag, "priorMean");
        this.priorPrecMat = (double[][]) XMLParser.extractObjectForTags(extractForTag, "priorPrecMat");
        this.tempValues = new double[this.dimension];
        this.alphabets = new AlphabetContainer(new ContinuousAlphabet());
        this.length = this.dimensionAlongPositions ? this.dimension : 0;
    }

    /**
     * Returns the mean followed by the slice separator and the rows of the precision matrix,
     * concatenated without separators. {@code numberFormat} is not used.
     *
     * @param numberFormat ignored
     * @return the textual representation
     */
    @Override
    public String toString(NumberFormat numberFormat) {
        StringBuilder out = new StringBuilder();
        out.append(Arrays.toString(this.mean)).append(AbstractFormatter.DEFAULT_SLICE_SEPARATOR);
        for (double[] row : this.precMat) {
            out.append(Arrays.toString(row));
        }
        return out.toString();
    }
}
