package de.jstacs.sequenceScores.statisticalModels.trainable.hmm.states.emissions.continuous;

import cern.colt.matrix.impl.AbstractFormatter;
import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.alphabets.ContinuousAlphabet;
import de.jstacs.data.sequences.MultiDimensionalSequence;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.io.ArrayHandler;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.states.emissions.Emission;
import de.jstacs.utils.random.RandomNumberGenerator;
import de.jtem.numericalMethods.algebra.group.Permutation;
import de.jtem.numericalMethods.algebra.linear.Determinant;
import de.jtem.numericalMethods.algebra.linear.Inversion;
import java.text.NumberFormat;
import java.util.Arrays;
import javax.naming.OperationNotSupportedException;
import umontreal.iro.lecuyer.randvar.NormalGen;
import umontreal.iro.lecuyer.randvarmulti.MultinormalCholeskyGen;
import umontreal.iro.lecuyer.rng.LFSR258;

/* loaded from: input_file:de/jstacs/sequenceScores/statisticalModels/trainable/hmm/states/emissions/continuous/FixedCorrelationGaussianEmission.class */
/**
 * Multivariate Gaussian emission whose correlation matrix is fixed.
 *
 * <p>The correlation matrix handed to the constructor is never re-estimated. What is learned are
 * the per-dimension standard deviations ({@code sds}) and, unless {@code meanFixed} is set, the
 * mean vector. The covariance matrix is assembled as {@code corrMat[i][j] * sds[i] * sds[j]} and
 * then inverted, so that {@code prePrec} holds the precision matrix that is used for scoring.
 *
 * <p>Hyper-parameters are {@code priorMean}/{@code essMu} for the mean and
 * {@code priorPrec}/{@code essPrec} for the per-dimension scales. NOTE(review): the exact prior
 * family is only implied by the gamma draw in {@link #initializeFunctionRandomly()} and by
 * {@link #getLogPriorTerm()}; confirm it against the model description.
 *
 * <p>Instances use mutable scratch buffers (e.g. {@code tempValues}) while scoring, so a single
 * instance must not be used concurrently.
 */
public class FixedCorrelationGaussianEmission implements Emission {

    /** Natural logarithm of 2&pi;, used in the Gaussian normalisation constant. */
    private static final double LOG_TWO_PI = Math.log(2.0d * Math.PI);

    /** Number of dimensions of an emitted vector. */
    private final int dimension;
    /** Alphabet container reported by {@link #getAlphabetContainer()}. */
    private final AlphabetContainer con;
    /** Current mean vector. */
    private double[] mean;
    /**
     * Precision matrix (inverse covariance) used for scoring. While parameters are rebuilt it
     * temporarily holds the covariance matrix, which is then inverted in place of it.
     */
    private double[][] prePrec;
    /** Current per-dimension standard deviations. */
    private double[] sds;
    /** Natural logarithm of the determinant of {@code prePrec}. */
    private double logDet;
    /** Scratch buffer that receives one observation vector at a time. */
    private double[] tempValues;
    /** Prior mean vector. */
    private double[] priorMean;
    /** Fixed correlation matrix. */
    private double[][] corrMat;
    /** Inverse of {@code corrMat}, computed once in the constructor. */
    private double[][] invCorrMat;
    /** Scratch matrix; also swapped with {@code prePrec} when the covariance is inverted. */
    private double[][] precTemp;
    /** Scratch matrix of per-pair scale targets used by {@link #estimateFromStatistic()}. */
    private double[][] fullPrecMat;
    /** Per-dimension prior scale hyper-parameters. */
    private double[] priorPrec;
    /** Weight of the prior on the mean (presumably an equivalent sample size). */
    private double essMu;
    /** Weight of the prior on the scales (presumably an equivalent sample size). */
    private double essPrec;
    /** Weighted sum of the observation vectors. */
    private double[] meanStat;
    /** Weighted sum of the outer products of the observation vectors. */
    private double[][] prcStat;
    /** Sum of the observation weights. */
    private double numStat;
    /** If {@code true}, the mean is pinned to {@code priorMean}. */
    private boolean meanFixed;
    /**
     * Weight that blends the closed-form standard deviations into the refined ones in
     * {@link #estimateFromStatistic()}. It is doubled after every estimation and reset to 1 on
     * construction and cloning.
     */
    private double decay;

    /**
     * Creates a new emission.
     *
     * @param dimension number of dimensions of an emitted vector
     * @param priorMean prior mean vector (copied)
     * @param priorPrec per-dimension prior scale hyper-parameters (copied)
     * @param essMu weight of the prior on the mean
     * @param essPrec weight of the prior on the scales
     * @param corrMat fixed correlation matrix (copied); it is also inverted here
     * @param meanFixed whether the mean stays at {@code priorMean}
     * @throws CloneNotSupportedException if copying {@code corrMat} fails
     */
    public FixedCorrelationGaussianEmission(int dimension, double[] priorMean, double[] priorPrec,
            double essMu, double essPrec, double[][] corrMat, boolean meanFixed)
            throws CloneNotSupportedException {
        this.dimension = dimension;
        this.priorMean = priorMean.clone();
        this.priorPrec = priorPrec.clone();
        this.essMu = essMu;
        this.essPrec = essPrec;
        this.corrMat = (double[][]) ArrayHandler.clone(corrMat);
        this.invCorrMat = new double[dimension][dimension];
        Inversion.compute(corrMat, this.invCorrMat);
        this.meanFixed = meanFixed;
        this.mean = new double[dimension];
        this.prePrec = new double[dimension][dimension];
        this.fullPrecMat = new double[dimension][dimension];
        this.precTemp = new double[dimension][dimension];
        this.sds = new double[dimension];
        this.tempValues = new double[dimension];
        this.meanStat = new double[dimension];
        this.prcStat = new double[dimension][dimension];
        this.con = new AlphabetContainer(new ContinuousAlphabet());
        this.decay = 1.0d;
    }

    /**
     * Returns a deep copy of this emission. All arrays are duplicated. The alphabet container is
     * shared, and {@code decay} is reset to 1 in the copy.
     *
     * @return the copy
     * @throws CloneNotSupportedException if an element cannot be cloned
     */
    @Override
    public FixedCorrelationGaussianEmission clone() throws CloneNotSupportedException {
        FixedCorrelationGaussianEmission copy = (FixedCorrelationGaussianEmission) super.clone();
        copy.mean = this.mean.clone();
        copy.prePrec = (double[][]) ArrayHandler.clone(this.prePrec);
        copy.meanStat = this.meanStat.clone();
        copy.prcStat = (double[][]) ArrayHandler.clone(this.prcStat);
        copy.precTemp = (double[][]) ArrayHandler.clone(this.precTemp);
        copy.fullPrecMat = (double[][]) ArrayHandler.clone(this.fullPrecMat);
        copy.sds = this.sds.clone();
        copy.corrMat = (double[][]) ArrayHandler.clone(this.corrMat);
        copy.invCorrMat = (double[][]) ArrayHandler.clone(this.invCorrMat);
        copy.priorMean = this.priorMean.clone();
        copy.priorPrec = this.priorPrec.clone();
        copy.tempValues = this.tempValues.clone();
        copy.decay = 1.0d;
        return copy;
    }

    /**
     * Decompiler-generated alias of {@link #clone()}, kept for backward compatibility.
     *
     * @return the same as {@link #clone()}
     * @throws CloneNotSupportedException if an element cannot be cloned
     * @deprecated use {@link #clone()} instead
     */
    @Deprecated
    public FixedCorrelationGaussianEmission m158clone() throws CloneNotSupportedException {
        return clone();
    }

    /**
     * XML serialisation is not implemented for this emission.
     *
     * @return always {@code null}
     */
    @Override
    public StringBuffer toXML() {
        return null;
    }

    /**
     * Returns the alphabet container created in the constructor (a single continuous alphabet).
     *
     * @return the alphabet container
     */
    @Override
    public AlphabetContainer getAlphabetContainer() {
        return this.con;
    }

    /**
     * Draws random starting parameters.
     *
     * <p>Each dimension gets a gamma-distributed value (shape/scale as interpreted by jstacs'
     * {@code nextGamma(essPrec, 1 / priorPrec[i])}). That value is treated as a precision: its
     * inverse becomes the variance and its inverse square root the standard deviation. The mean
     * is drawn from a multivariate normal around {@code priorMean} with the resulting covariance.
     * Finally the covariance is inverted into the precision matrix.
     */
    @Override
    public void initializeFunctionRandomly() {
        RandomNumberGenerator random = new RandomNumberGenerator();
        for (int i = 0; i < this.prePrec.length; i++) {
            double precision = random.nextGamma(this.essPrec, 1.0d / this.priorPrec[i]);
            this.prePrec[i][i] = 1.0d / precision;
            this.sds[i] = Math.sqrt(1.0d / precision);
        }
        for (int row = 0; row < this.prePrec.length; row++) {
            for (int col = 0; col < row; col++) {
                double covariance = this.corrMat[row][col] * this.sds[row] * this.sds[col];
                this.prePrec[row][col] = covariance;
                this.prePrec[col][row] = covariance;
            }
        }
        // prePrec currently holds the covariance matrix. The draw also happens when the mean is
        // fixed; its result is then replaced by the prior mean.
        new MultinormalCholeskyGen(new NormalGen(new LFSR258()), this.priorMean, this.prePrec)
                .nextPoint(this.mean);
        if (this.meanFixed) {
            this.mean = this.priorMean.clone();
        }
        replaceCovarianceByPrecision();
    }

    /**
     * Inverts the covariance matrix currently stored in {@code prePrec}. The inverse becomes the
     * new {@code prePrec}, the old buffer is kept as the new {@code precTemp} scratch matrix, and
     * {@code logDet} is updated to the log-determinant of the new precision matrix.
     */
    private void replaceCovarianceByPrecision() {
        Inversion.compute(this.prePrec, this.precTemp);
        double[][] covariance = this.prePrec;
        this.prePrec = this.precTemp;
        this.precTemp = covariance;
        this.logDet = Math.log(Determinant.compute(this.prePrec));
    }

    /**
     * Log-density of one observation vector under the current mean and precision matrix.
     *
     * @param x the observation; only its first {@code x.length} entries are used
     * @return {@code 0.5 * (-(x-mean)^T prePrec (x-mean) + logDet - dimension * log(2 pi))}
     */
    private double getLogProbFor(double[] x) {
        double result = 0.0d;
        for (int i = 0; i < x.length; i++) {
            double deviationI = x[i] - this.mean[i];
            for (int j = 0; j < x.length; j++) {
                result -= (deviationI * this.prePrec[i][j]) * (x[j] - this.mean[j]);
            }
        }
        return (result + (this.logDet - (this.dimension * LOG_TWO_PI))) * 0.5d;
    }

    /**
     * Sums the log-density over the positions {@code startPos..endPos} (both inclusive).
     *
     * <p>For {@code dimension == 1} the values come from {@code Sequence.continuousVal}.
     * Otherwise {@code seq} must be a {@link MultiDimensionalSequence} with exactly
     * {@code dimension} parallel sequences.
     *
     * @param forward ignored by this implementation
     * @param startPos first position (inclusive)
     * @param endPos last position (inclusive)
     * @param seq the sequence to score
     * @return the summed log-density
     * @throws OperationNotSupportedException if the number of parallel sequences differs from
     *     {@code dimension}
     */
    @Override
    public double getLogProbFor(boolean forward, int startPos, int endPos, Sequence seq)
            throws OperationNotSupportedException {
        double logProb = 0.0d;
        if (this.dimension == 1) {
            for (int pos = startPos; pos <= endPos; pos++) {
                this.tempValues[0] = seq.continuousVal(pos);
                logProb += getLogProbFor(this.tempValues);
            }
        } else {
            MultiDimensionalSequence multiSeq = asMultiDimensional(seq);
            for (int pos = startPos; pos <= endPos; pos++) {
                multiSeq.fillContainer(this.tempValues, pos);
                logProb += getLogProbFor(this.tempValues);
            }
        }
        return logProb;
    }

    /**
     * Casts {@code seq} to a multi-dimensional sequence and checks its number of parallel
     * sequences.
     *
     * @param seq the sequence
     * @return {@code seq} as a {@link MultiDimensionalSequence}
     * @throws OperationNotSupportedException if it does not have {@code dimension} sequences
     */
    private MultiDimensionalSequence asMultiDimensional(Sequence seq)
            throws OperationNotSupportedException {
        MultiDimensionalSequence multiSeq = (MultiDimensionalSequence) seq;
        int parallel = multiSeq.getNumberOfSequences();
        if (parallel != this.dimension) {
            throw new OperationNotSupportedException("Expected " + this.dimension
                    + " parallel sequences but got " + parallel);
        }
        return multiSeq;
    }

    /**
     * Log prior term of the current parameters (up to constants).
     *
     * <p>If the mean is free, this adds {@code -(priorMean-mean)^T prePrec (priorMean-mean) /
     * (2 essMu)} and {@code 0.5 * logDet}. For every dimension it subtracts
     * {@code (2 essPrec + 1) log(sds[i]) + priorPrec[i] / sds[i]^2}.
     *
     * @return the log prior term
     */
    @Override
    public double getLogPriorTerm() {
        double logPrior = 0.0d;
        if (!this.meanFixed) {
            for (int i = 0; i < this.mean.length; i++) {
                double deviationI = this.priorMean[i] - this.mean[i];
                for (int j = 0; j < this.mean.length; j++) {
                    logPrior += deviationI * this.prePrec[i][j] * (this.priorMean[j] - this.mean[j]);
                }
            }
            logPrior *= (-1.0d) / (2.0d * this.essMu);
        }
        for (int i = 0; i < this.sds.length; i++) {
            logPrior -= (((2.0d * this.essPrec) + 1.0d) * Math.log(this.sds[i]))
                    + (this.priorPrec[i] / (this.sds[i] * this.sds[i]));
        }
        if (!this.meanFixed) {
            logPrior += 0.5d * this.logDet;
        }
        return logPrior;
    }

    /** Clears the accumulated sufficient statistics. */
    @Override
    public void resetStatistic() {
        Arrays.fill(this.meanStat, 0.0d);
        for (double[] row : this.prcStat) {
            Arrays.fill(row, 0.0d);
        }
        this.numStat = 0.0d;
    }

    /**
     * Adds the weighted observations at {@code startPos..endPos} (both inclusive) to the
     * sufficient statistics: the weighted sum, the weighted sum of outer products and the total
     * weight.
     *
     * @param forward ignored by this implementation
     * @param startPos first position (inclusive)
     * @param endPos last position (inclusive)
     * @param weight weight of every observation
     * @param seq the sequence
     * @throws OperationNotSupportedException if the number of parallel sequences differs from
     *     {@code dimension}
     */
    @Override
    public void addToStatistic(boolean forward, int startPos, int endPos, double weight,
            Sequence seq) throws OperationNotSupportedException {
        if (this.dimension == 1) {
            for (int pos = startPos; pos <= endPos; pos++) {
                double value = seq.continuousVal(pos);
                this.meanStat[0] += value * weight;
                this.prcStat[0][0] += value * value * weight;
                this.numStat += weight;
            }
            return;
        }
        MultiDimensionalSequence multiSeq = asMultiDimensional(seq);
        for (int pos = startPos; pos <= endPos; pos++) {
            multiSeq.fillContainer(this.tempValues, pos);
            for (int i = 0; i < this.tempValues.length; i++) {
                this.meanStat[i] += this.tempValues[i] * weight;
                for (int j = 0; j < this.prcStat[i].length; j++) {
                    this.prcStat[i][j] += this.tempValues[i] * this.tempValues[j] * weight;
                }
            }
            this.numStat += weight;
        }
    }

    /**
     * Pools the statistics of all given emissions (which must be
     * {@code FixedCorrelationGaussianEmission}s) into this one. The pooled statistics are then
     * copied back into every other emission. Entries identical to {@code this} are skipped.
     *
     * @param emissions the emissions to join
     */
    @Override
    public void joinStatistics(Emission... emissions) {
        for (Emission emission : emissions) {
            if (emission != this) {
                FixedCorrelationGaussianEmission other = (FixedCorrelationGaussianEmission) emission;
                for (int i = 0; i < this.meanStat.length; i++) {
                    this.meanStat[i] += other.meanStat[i];
                    for (int j = 0; j < this.prcStat[i].length; j++) {
                        this.prcStat[i][j] += other.prcStat[i][j];
                    }
                }
                this.numStat += other.numStat;
            }
        }
        for (Emission emission : emissions) {
            if (emission != this) {
                FixedCorrelationGaussianEmission other = (FixedCorrelationGaussianEmission) emission;
                System.arraycopy(this.meanStat, 0, other.meanStat, 0, this.meanStat.length);
                for (int i = 0; i < this.meanStat.length; i++) {
                    System.arraycopy(this.prcStat[i], 0, other.prcStat[i], 0, this.prcStat[i].length);
                }
                other.numStat = this.numStat;
            }
        }
    }

    /**
     * Re-estimates mean, standard deviations and precision matrix from the accumulated
     * statistics.
     *
     * <p>Steps:
     * <ol>
     *   <li>mean = prior mean if fixed, else {@code (essMu*priorMean + meanStat) / (numStat + essMu)};</li>
     *   <li>closed-form standard deviations per dimension (using the diagonal of the inverse
     *       correlation matrix);</li>
     *   <li>per-pair scale targets in {@code fullPrecMat};</li>
     *   <li>a randomised fixed-point refinement of the standard deviations;</li>
     *   <li>blending the refined values with the closed-form ones using {@code decay}, which is
     *       then doubled;</li>
     *   <li>rebuilding the covariance from the fixed correlation matrix and inverting it.</li>
     * </ol>
     */
    @Override
    public void estimateFromStatistic() {
        if (this.meanFixed) {
            this.mean = this.priorMean.clone();
        } else {
            for (int i = 0; i < this.mean.length; i++) {
                this.mean[i] = ((this.essMu * this.priorMean[i]) + this.meanStat[i])
                        / (this.numStat + this.essMu);
            }
        }
        // precTemp := (mean - priorMean)(mean - priorMean)^T, or all zeros for a fixed mean.
        for (int i = 0; i < this.mean.length; i++) {
            for (int j = 0; j < this.mean.length; j++) {
                this.precTemp[i][j] = this.meanFixed
                        ? 0.0d
                        : (this.mean[i] - this.priorMean[i]) * (this.mean[j] - this.priorMean[j]);
            }
        }
        // Closed-form per-dimension standard deviations.
        double meanDegree = this.meanFixed ? 0.0d : 1.0d;
        for (int i = 0; i < this.sds.length; i++) {
            double scatter = (this.prcStat[i][i] - ((2.0d * this.mean[i]) * this.meanStat[i]))
                    + (this.numStat * this.mean[i] * this.mean[i]);
            double variance = ((this.invCorrMat[i][i] * (scatter + (this.essMu * this.precTemp[i][i])))
                    + (2.0d * this.priorPrec[i]))
                    / (((this.numStat + (2.0d * this.essPrec)) + 1.0d) + meanDegree);
            this.sds[i] = Math.sqrt(variance);
        }
        // Per-pair targets; the diagonal is overwritten with the closed-form variances.
        for (int i = 0; i < this.fullPrecMat.length; i++) {
            for (int j = 0; j < this.fullPrecMat[i].length; j++) {
                double scatter = ((this.prcStat[i][j] - (this.mean[i] * this.meanStat[j]))
                        - (this.mean[j] * this.meanStat[i]))
                        + (this.numStat * this.mean[i] * this.mean[j]);
                double priorTerm = (i == j) ? this.priorPrec[i] : 0.0d;
                this.fullPrecMat[i][j] = ((priorTerm + scatter) + (this.essMu * this.precTemp[i][j]))
                        / ((this.essPrec - this.dimension) + this.numStat);
            }
            this.fullPrecMat[i][i] = this.sds[i] * this.sds[i];
        }

        double[] closedFormSds = this.sds.clone();
        refineStandardDeviations();
        for (int i = 0; i < this.sds.length; i++) {
            this.sds[i] = (this.sds[i] + (this.decay * closedFormSds[i])) / (this.decay + 1.0d);
        }
        this.decay *= 2.0d;

        // Covariance = corrMat[i][j] * sds[i] * sds[j], stored in prePrec before inversion.
        for (int i = 0; i < this.prePrec.length; i++) {
            for (int j = 0; j <= i; j++) {
                double covariance = this.corrMat[i][j] * this.sds[i] * this.sds[j];
                this.prePrec[i][j] = covariance;
                this.prePrec[j][i] = covariance;
            }
        }
        replaceCovarianceByPrecision();
    }

    /**
     * Fixed-point refinement of {@code sds}. In every sweep the dimensions are visited in a
     * random order. Each {@code sds[i]} is replaced by the average, over all {@code j} with
     * non-zero correlation, of {@code sqrt(fullPrecMat[i][i])} (for {@code j == i}) or
     * {@code fullPrecMat[i][j] / sds[j] / corrMat[i][j]} (for {@code j != i}). Sweeps repeat while
     * the summed squared change exceeds {@code 1e-6}.
     *
     * <p>NOTE(review): there is no iteration cap; a non-converging configuration loops forever
     * (a NaN change terminates the loop).
     */
    private void refineStandardDeviations() {
        int[] order = new int[this.sds.length];
        double change;
        do {
            change = 0.0d;
            Permutation.random(order);
            for (int dim : order) {
                double sum = 0.0d;
                double count = 0.0d;
                for (int j = 0; j < this.fullPrecMat[dim].length; j++) {
                    if (this.corrMat[dim][j] != 0.0d) {
                        if (dim == j) {
                            sum += Math.sqrt(this.fullPrecMat[dim][j]);
                        } else {
                            sum += (this.fullPrecMat[dim][j] / this.sds[j]) / this.corrMat[dim][j];
                        }
                        count += 1.0d;
                    }
                }
                double estimate = sum / count;
                change += (this.sds[dim] - estimate) * (this.sds[dim] - estimate);
                this.sds[dim] = estimate;
            }
        } while (change > 1.0E-6d);
    }

    /**
     * No specific node shape is provided.
     *
     * @param forward ignored
     * @return always {@code null}
     */
    @Override
    public String getNodeShape(boolean forward) {
        return null;
    }

    /**
     * Returns the constant node label {@code "G"}.
     *
     * @param weight ignored
     * @param name ignored
     * @param nf ignored
     * @return {@code "G"}
     */
    @Override
    public String getNodeLabel(double weight, String name, NumberFormat nf) {
        return "G";
    }

    /**
     * Copies mean, standard deviations, precision matrix and log-determinant from another
     * emission of this class. Sufficient statistics, priors and {@code decay} are not copied.
     *
     * @param emission the source emission
     * @throws IllegalArgumentException if {@code emission} is not a
     *     {@code FixedCorrelationGaussianEmission}
     */
    @Override
    public void setParameters(Emission emission) throws IllegalArgumentException {
        if (!(emission instanceof FixedCorrelationGaussianEmission)) {
            throw new IllegalArgumentException("Expected a FixedCorrelationGaussianEmission but got "
                    + (emission == null ? "null" : emission.getClass().getName()));
        }
        FixedCorrelationGaussianEmission other = (FixedCorrelationGaussianEmission) emission;
        for (int i = 0; i < this.mean.length; i++) {
            this.mean[i] = other.mean[i];
            this.sds[i] = other.sds[i];
            System.arraycopy(other.prePrec[i], 0, this.prePrec[i], 0, this.prePrec[i].length);
        }
        this.logDet = other.logDet;
    }

    /**
     * Returns the mean vector, a slice separator and the rows of the covariance matrix (the
     * inverse of the precision matrix), each rendered like {@link Arrays#toString(double[])}.
     *
     * @return the textual representation
     */
    @Override
    public String toString() {
        return describe(null);
    }

    /**
     * Same layout as {@link #toString()}, but every number is formatted with {@code nf}.
     *
     * @param nf the number format; {@code null} falls back to {@link Arrays#toString(double[])}
     * @return the textual representation
     */
    @Override
    public String toString(NumberFormat nf) {
        return describe(nf);
    }

    /**
     * Builds the textual representation. The covariance is computed into a local buffer so that
     * printing does not touch the scratch matrices of this instance.
     *
     * @param nf the number format or {@code null}
     * @return the textual representation
     */
    private String describe(NumberFormat nf) {
        double[][] covariance = new double[this.prePrec.length][this.prePrec.length];
        Inversion.compute(this.prePrec, covariance);
        StringBuilder sb = new StringBuilder(formatVector(this.mean, nf));
        sb.append(AbstractFormatter.DEFAULT_SLICE_SEPARATOR);
        for (double[] row : covariance) {
            sb.append(formatVector(row, nf));
        }
        return sb.toString();
    }

    /**
     * Renders a vector as {@code [a, b, c]}.
     *
     * @param values the vector
     * @param nf the number format or {@code null} for {@link Arrays#toString(double[])}
     * @return the rendered vector
     */
    private static String formatVector(double[] values, NumberFormat nf) {
        if (nf == null) {
            return Arrays.toString(values);
        }
        StringBuilder sb = new StringBuilder("[");
        for (int i = 0; i < values.length; i++) {
            if (i > 0) {
                sb.append(", ");
            }
            sb.append(nf.format(values[i]));
        }
        return sb.append(']').toString();
    }
}
