package de.jstacs.sequenceScores.statisticalModels.trainable.hmm.states.emissions.continuous;

import Jama.Matrix;
import cern.colt.matrix.impl.AbstractFormatter;
import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.alphabets.ContinuousAlphabet;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.io.ArrayHandler;
import de.jstacs.io.NonParsableException;
import de.jstacs.io.XMLParser;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.states.emissions.Emission;
import htsjdk.variant.vcf.VCFConstants;
import java.text.NumberFormat;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import javax.naming.OperationNotSupportedException;
import projects.dispom.DispomParameterSet;

/* loaded from: input_file:de/jstacs/sequenceScores/statisticalModels/trainable/hmm/states/emissions/continuous/MultivariateGaussianEmission.class */
/**
 * HMM state emission that scores continuous, multivariate observations with a single
 * multivariate Gaussian density.
 *
 * <p>The covariance matrix is parameterised by per-dimension standard deviations
 * ({@link #sds}) and the upper triangle ({@code [i][j]} with {@code i < j}) of a correlation
 * matrix ({@link #correlation}). Its inverse is cached in {@link #inverseCov}. Parameters are
 * re-estimated from gamma-weighted statistics combined with prior hyper-parameters
 * ({@link #aprioriMean}, {@link #scaleMean}, {@link #shapeSd}, {@link #scaleSd}).
 * NOTE(review): the form of {@link #getLogPriorTerm()} looks like a Normal-Wishart prior on
 * mean and precision — confirm before relying on that interpretation.
 *
 * <p>Several methods reuse the {@link #emission} scratch buffer, so an instance must not be
 * used from several threads concurrently.
 */
public class MultivariateGaussianEmission implements Emission {
    /** Number of dimensions of each observation (length of {@link #mean}). */
    int dim;
    /** Mean restored by {@link #initializeFunctionRandomly()}. */
    double[] initialMean;
    /** Current mean vector; updated in place by estimation and {@link #setParameters}. */
    double[] mean;
    /** Standard deviations restored by {@link #initializeFunctionRandomly()}. */
    double[] initialSds;
    /** Current standard deviations, one per dimension. */
    double[] sds;
    /** Correlations restored by {@link #initializeFunctionRandomly()} (upper triangle used). */
    double[][] initialCorrelation;
    /** Current correlations; only entries {@code [i][j]} with {@code i < j} are read. */
    double[][] correlation;
    /** Cached inverse of the covariance matrix built from {@link #sds} and {@link #correlation}. */
    private Matrix inverseCov;
    /** Prior mean vector. */
    private double[] aprioriMean;
    /** Prior weight (pseudo-count) of {@link #aprioriMean}. */
    private double scaleMean;
    /** Prior hyper-parameter that enters the covariance estimate as {@code shapeSd - dim}. */
    private double shapeSd;
    /** Prior scale matrix added to the scatter matrix during estimation. */
    private double[][] scaleSd;
    /** Per-sequence, per-position accumulated weights collected by {@link #addToStatistic}. */
    protected HashMap<Sequence, double[]> gammas;
    /** Sum of all weights added to the statistics. */
    private double sumOfGammas;
    /** Weight-scaled sum of all observation vectors added to the statistics. */
    private double[] sumOfGammaWeightedEmissions;
    private AlphabetContainer con;
    private static final String TAG = "MultivariateGaussianEmission";
    /** Scratch buffer holding one observation vector; overwritten on every use. */
    double[] emission;

    /**
     * Creates a new emission from initial parameters and prior hyper-parameters. All arrays
     * are copied, so later changes to the arguments do not affect this instance.
     *
     * @param initialMean initial mean vector; its length defines the dimension
     * @param initialSds initial standard deviations, one per dimension
     * @param initialCorrelation initial correlations; only entries {@code [i][j]} with
     *     {@code i < j} are used
     * @param scaleMean prior weight (pseudo-count) of {@code aprioriMean}
     * @param aprioriMean prior mean vector
     * @param shapeSd prior hyper-parameter; the covariance estimate is divided by
     *     {@code sumOfGammas + shapeSd - dim}
     * @param scaleSd prior scale matrix added to the scatter matrix
     */
    public MultivariateGaussianEmission(double[] initialMean, double[] initialSds,
            double[][] initialCorrelation, double scaleMean, double[] aprioriMean,
            double shapeSd, double[][] scaleSd) {
        this.dim = initialMean.length;
        this.emission = new double[this.dim];
        this.initialMean = initialMean.clone();
        this.initialSds = initialSds.clone();
        // Deep copies: the rows of the current correlation matrix are overwritten in place by
        // setParameters, which must not leak into the initial values or the caller's array.
        this.initialCorrelation = deepCopy(initialCorrelation);
        this.mean = initialMean.clone();
        this.sds = initialSds.clone();
        this.correlation = deepCopy(initialCorrelation);
        this.inverseCov = getInverseCovarianceMatrix();
        this.scaleMean = scaleMean;
        this.aprioriMean = aprioriMean.clone();
        this.shapeSd = shapeSd;
        this.scaleSd = deepCopy(scaleSd);
        this.gammas = new HashMap<>();
        // Allocate the statistic up front so addToStatistic/joinStatistics work even if
        // resetStatistic has not been called yet.
        this.sumOfGammaWeightedEmissions = new double[this.dim];
        this.con = new AlphabetContainer(new ContinuousAlphabet());
    }

    /**
     * Restores an emission from the XML representation written by {@link #toXML()} and
     * starts with empty statistics.
     *
     * @param xml the XML representation
     * @throws NonParsableException if the XML cannot be parsed
     */
    public MultivariateGaussianEmission(StringBuffer xml) throws NonParsableException {
        fromXML(xml);
        this.gammas = new HashMap<>();
        resetStatistic();
    }

    /**
     * Adds the observations at positions {@code startPos..endPos} (inclusive) of {@code seq},
     * each with weight {@code weight}, to the sufficient statistics.
     *
     * @param forward direction flag; only {@code true} is supported (presumably the forward
     *     strand — confirm against {@link Emission})
     * @param startPos first position (inclusive)
     * @param endPos last position (inclusive)
     * @param weight weight added for every position
     * @param seq the sequence providing the observations
     * @throws OperationNotSupportedException if {@code forward} is {@code false}
     */
    @Override
    public void addToStatistic(boolean forward, int startPos, int endPos, double weight,
            Sequence seq) throws OperationNotSupportedException {
        if (!forward) {
            throw new OperationNotSupportedException();
        }
        double[] seqGammas = this.gammas.get(seq);
        if (seqGammas == null) {
            seqGammas = new double[seq.getLength()];
            this.gammas.put(seq, seqGammas);
        }
        for (int pos = startPos; pos <= endPos; pos++) {
            seqGammas[pos] += weight;
            seq.fillContainer(this.emission, pos);
            this.sumOfGammas += weight;
            for (int k = 0; k < this.dim; k++) {
                this.sumOfGammaWeightedEmissions[k] += weight * this.emission[k];
            }
        }
    }

    /**
     * Merges the statistics of all given emissions into this one, then copies the merged
     * statistics back into every other emission, so that all of them end up with identical
     * statistics. Entries identical to {@code this} are skipped.
     *
     * @param emissionArr emissions to join; all must be {@code MultivariateGaussianEmission}s
     */
    @Override
    public void joinStatistics(Emission... emissionArr) {
        // Pass 1: accumulate everyone's statistics into this instance.
        for (Emission e : emissionArr) {
            if (e == this) {
                continue;
            }
            MultivariateGaussianEmission other = (MultivariateGaussianEmission) e;
            this.sumOfGammas += other.sumOfGammas;
            for (int k = 0; k < this.dim; k++) {
                this.sumOfGammaWeightedEmissions[k] += other.sumOfGammaWeightedEmissions[k];
            }
            for (Map.Entry<Sequence, double[]> entry : other.gammas.entrySet()) {
                double[] mine = this.gammas.get(entry.getKey());
                double[] theirs = entry.getValue();
                if (mine == null) {
                    this.gammas.put(entry.getKey(), theirs.clone());
                } else {
                    for (int pos = 0; pos < theirs.length; pos++) {
                        mine[pos] += theirs[pos];
                    }
                }
            }
        }
        // Pass 2: distribute private copies of the merged statistics to the others.
        for (Emission e : emissionArr) {
            if (e == this) {
                continue;
            }
            MultivariateGaussianEmission other = (MultivariateGaussianEmission) e;
            other.sumOfGammas = this.sumOfGammas;
            System.arraycopy(this.sumOfGammaWeightedEmissions, 0,
                    other.sumOfGammaWeightedEmissions, 0, this.dim);
            other.gammas.clear();
            for (Map.Entry<Sequence, double[]> entry : this.gammas.entrySet()) {
                other.gammas.put(entry.getKey(), entry.getValue().clone());
            }
        }
    }

    /**
     * Re-estimates mean, standard deviations, correlations and the cached inverse covariance
     * from the collected statistics combined with the prior hyper-parameters.
     *
     * <p>The mean is {@code (sum gamma*x + scaleMean*aprioriMean) / (sumOfGammas + scaleMean)}.
     * The covariance is {@code (scaleSd + scaleMean*(m-a)^T(m-a) + sum gamma*(x-m)^T(x-m))}
     * divided by {@code sumOfGammas + shapeSd - dim}. If no statistics were collected, only the
     * prior terms contribute.
     */
    @Override
    public void estimateFromStatistic() {
        for (int k = 0; k < this.dim; k++) {
            this.mean[k] = (this.sumOfGammaWeightedEmissions[k] + (this.aprioriMean[k] * this.scaleMean))
                    / (this.sumOfGammas + this.scaleMean);
        }
        // Row vector view of the (now updated) mean; it is not modified below.
        Matrix meanRow = new Matrix(new double[][] {this.mean});

        // Prior part of the scatter matrix.
        Matrix priorDiff = meanRow.minus(new Matrix(new double[][] {this.aprioriMean}));
        Matrix scatter = new Matrix(this.scaleSd)
                .plus(priorDiff.transpose().times(priorDiff).times(this.scaleMean));

        // Data part: gamma-weighted outer products of the centered observations.
        for (Map.Entry<Sequence, double[]> entry : this.gammas.entrySet()) {
            Sequence seq = entry.getKey();
            double[] seqGammas = entry.getValue();
            for (int pos = 0; pos < seqGammas.length; pos++) {
                seq.fillContainer(this.emission, pos);
                Matrix diff = new Matrix(new double[][] {this.emission}).minus(meanRow);
                scatter = scatter.plus(diff.transpose().times(diff).times(seqGammas[pos]));
            }
        }

        double denominator = (this.sumOfGammas + this.shapeSd) - this.dim;
        Matrix covariance = scatter.times(1.0d / denominator);
        for (int k = 0; k < this.dim; k++) {
            this.sds[k] = Math.sqrt(covariance.get(k, k));
        }
        this.correlation = getCorrelations(covariance, this.sds);
        this.inverseCov = covariance.inverse();
    }

    /** Returns the alphabet container of this emission (a continuous alphabet by default). */
    @Override
    public AlphabetContainer getAlphabetContainer() {
        return this.con;
    }

    /**
     * Returns the (unnormalized) log prior of the current parameters:
     * {@code (shapeSd - dim)/2 * log|P| - scaleMean/2 * (m-a) P (m-a)^T - tr(scaleSd P)/2},
     * where {@code P} is the cached inverse covariance, {@code m} the mean and {@code a} the
     * prior mean.
     *
     * @return the log prior term
     */
    @Override
    public double getLogPriorTerm() {
        double logDetPrecision = Math.log(this.inverseCov.det());
        Matrix diff = new Matrix(new double[][] {this.mean})
                .minus(new Matrix(new double[][] {this.aprioriMean}));
        double quadratic = diff.times(this.inverseCov).times(diff.transpose()).get(0, 0);
        double trace = new Matrix(this.scaleSd).times(this.inverseCov).trace();
        double log = ((this.shapeSd - this.dim) / 2.0d) * logDetPrecision;
        return (log - ((this.scaleMean / 2.0d) * quadratic)) - (0.5d * trace);
    }

    /**
     * Returns the summed Gaussian log density of the observations at positions
     * {@code startPos..endPos} (inclusive) of {@code seq}; {@code 0} for an empty range.
     *
     * @param forward direction flag; NOTE(review): ignored by this implementation
     * @param startPos first position (inclusive)
     * @param endPos last position (inclusive)
     * @param seq the sequence providing the observations
     * @return the log density
     * @throws OperationNotSupportedException declared by the interface; not thrown here
     */
    @Override
    public double getLogProbFor(boolean forward, int startPos, int endPos, Sequence seq)
            throws OperationNotSupportedException {
        // Position-independent parts of the log density, computed once.
        double logNormalizer = Math.log(Math.sqrt(Math.pow(2.0d * Math.PI, this.dim)));
        double halfLogDetPrecision = 0.5d * Math.log(this.inverseCov.det());
        Matrix meanRow = new Matrix(new double[][] {this.mean});
        double logProb = 0.0d;
        for (int pos = startPos; pos <= endPos; pos++) {
            seq.fillContainer(this.emission, pos);
            Matrix diff = new Matrix(new double[][] {this.emission}).minus(meanRow);
            double quadratic = diff.times(this.inverseCov).times(diff.transpose()).get(0, 0);
            logProb = (logProb - logNormalizer) + halfLogDetPrecision - (0.5d * quadratic);
        }
        return logProb;
    }

    /** Not provided for this emission; always returns {@code null}. */
    @Override
    public String getNodeLabel(double d, String str, NumberFormat numberFormat) {
        return null;
    }

    /** Not provided for this emission; always returns {@code null}. */
    @Override
    public String getNodeShape(boolean z) {
        return null;
    }

    /**
     * Resets the parameters to the initial values given at construction (despite the name,
     * no randomness is involved) and recomputes the cached inverse covariance.
     */
    @Override
    public void initializeFunctionRandomly() {
        this.mean = this.initialMean.clone();
        this.sds = this.initialSds.clone();
        // Deep copy so later in-place updates of correlation rows cannot alter the initial values.
        this.correlation = deepCopy(this.initialCorrelation);
        this.inverseCov = getInverseCovarianceMatrix();
    }

    /** Clears all collected statistics (weights are zeroed, known sequences are kept). */
    @Override
    public void resetStatistic() {
        this.sumOfGammas = 0.0d;
        if (this.sumOfGammaWeightedEmissions == null) {
            this.sumOfGammaWeightedEmissions = new double[this.dim];
        } else {
            Arrays.fill(this.sumOfGammaWeightedEmissions, 0.0d);
        }
        resetGammas();
    }

    /** Sets every stored per-position weight to zero without removing sequences. */
    private void resetGammas() {
        for (double[] seqGammas : this.gammas.values()) {
            Arrays.fill(seqGammas, 0.0d);
        }
    }

    /**
     * Serializes parameters and hyper-parameters (not the statistics) to XML.
     *
     * @return the XML representation enclosed in a {@value #TAG} tag
     */
    @Override
    public StringBuffer toXML() {
        StringBuffer stringBuffer = new StringBuffer();
        XMLParser.appendObjectWithTags(stringBuffer, this.con, "alphabet");
        XMLParser.appendObjectWithTags(stringBuffer, this.initialMean, "initialMean");
        XMLParser.appendObjectWithTags(stringBuffer, this.mean, DispomParameterSet.MEAN);
        XMLParser.appendObjectWithTags(stringBuffer, this.initialSds, "initialSds");
        XMLParser.appendObjectWithTags(stringBuffer, this.sds, "sds");
        XMLParser.appendObjectWithTags(stringBuffer, this.initialCorrelation, "initialCorrelation");
        XMLParser.appendObjectWithTags(stringBuffer, this.correlation, "correlation");
        XMLParser.appendObjectWithTags(stringBuffer, this.aprioriMean, "aprioriMean");
        XMLParser.appendObjectWithTags(stringBuffer, Double.valueOf(this.scaleMean), "scaleMean");
        XMLParser.appendObjectWithTags(stringBuffer, Double.valueOf(this.shapeSd), "shapeSd");
        XMLParser.appendObjectWithTags(stringBuffer, this.scaleSd, "scaleSd");
        XMLParser.addTags(stringBuffer, TAG);
        return stringBuffer;
    }

    /**
     * Restores parameters and hyper-parameters from the XML written by {@link #toXML()}, then
     * derives {@link #dim}, allocates {@link #emission} and recomputes the inverse covariance.
     *
     * @param stringBuffer the XML representation
     * @throws NonParsableException if the XML cannot be parsed
     */
    protected void fromXML(StringBuffer stringBuffer) throws NonParsableException {
        StringBuffer extractForTag = XMLParser.extractForTag(stringBuffer, TAG);
        this.con = (AlphabetContainer) XMLParser.extractObjectForTags(extractForTag, "alphabet", AlphabetContainer.class);
        this.initialMean = (double[]) XMLParser.extractObjectForTags(extractForTag, "initialMean");
        this.mean = (double[]) XMLParser.extractObjectForTags(extractForTag, DispomParameterSet.MEAN);
        this.initialSds = (double[]) XMLParser.extractObjectForTags(extractForTag, "initialSds");
        this.sds = (double[]) XMLParser.extractObjectForTags(extractForTag, "sds");
        this.initialCorrelation = (double[][]) XMLParser.extractObjectForTags(extractForTag, "initialCorrelation");
        this.correlation = (double[][]) XMLParser.extractObjectForTags(extractForTag, "correlation");
        this.aprioriMean = (double[]) XMLParser.extractObjectForTags(extractForTag, "aprioriMean");
        this.scaleMean = ((Double) XMLParser.extractObjectForTags(extractForTag, "scaleMean", Double.TYPE)).doubleValue();
        this.shapeSd = ((Double) XMLParser.extractObjectForTags(extractForTag, "shapeSd", Double.TYPE)).doubleValue();
        this.scaleSd = (double[][]) XMLParser.extractObjectForTags(extractForTag, "scaleSd");
        this.dim = this.mean.length;
        this.emission = new double[this.dim];
        this.inverseCov = getInverseCovarianceMatrix();
    }

    /**
     * Returns a human-readable summary of means, standard deviations and pairwise
     * correlations, with numbers formatted by {@code numberFormat}.
     *
     * @param numberFormat formatter for all numeric values
     * @return the summary text
     */
    @Override
    public String toString(NumberFormat numberFormat) {
        StringBuilder sb = new StringBuilder("- Means  = ");
        for (int i = 0; i < this.dim; i++) {
            sb.append(numberFormat.format(this.mean[i])).append('\t');
        }
        sb.append(AbstractFormatter.DEFAULT_SLICE_SEPARATOR);
        for (int i = 0; i < this.dim; i++) {
            sb.append("- Standard dev. = ").append(numberFormat.format(this.sds[i]))
                    .append(AbstractFormatter.DEFAULT_ROW_SEPARATOR);
        }
        sb.append(AbstractFormatter.DEFAULT_SLICE_SEPARATOR);
        for (int i = 0; i < this.dim; i++) {
            for (int j = i + 1; j < this.dim; j++) {
                sb.append("- Correlation(").append(i + 1)
                        .append(VCFConstants.INFO_FIELD_ARRAY_SEPARATOR).append(j + 1)
                        .append(")  = ").append(numberFormat.format(this.correlation[i][j]))
                        .append(AbstractFormatter.DEFAULT_ROW_SEPARATOR);
            }
        }
        return sb.toString();
    }

    /**
     * Returns an independent deep copy of this emission. All mutable arrays (including
     * {@link #mean}, which is updated in place during training) and the cached inverse
     * covariance are copied; the alphabet container and the sequence keys of the statistics
     * are shared.
     *
     * @return the copy
     * @throws CloneNotSupportedException if {@code Object.clone()} refuses to clone
     */
    @Override
    public MultivariateGaussianEmission clone() throws CloneNotSupportedException {
        MultivariateGaussianEmission copy = (MultivariateGaussianEmission) super.clone();
        copy.mean = this.mean == null ? null : this.mean.clone();
        copy.aprioriMean = this.aprioriMean == null ? null : this.aprioriMean.clone();
        copy.correlation = deepCopy(this.correlation);
        copy.emission = this.emission == null ? null : this.emission.clone();
        if (this.gammas != null) {
            copy.gammas = new HashMap<>();
            for (Map.Entry<Sequence, double[]> entry : this.gammas.entrySet()) {
                copy.gammas.put(entry.getKey(), entry.getValue().clone());
            }
        } else {
            copy.gammas = null;
        }
        copy.initialCorrelation = deepCopy(this.initialCorrelation);
        copy.initialMean = this.initialMean == null ? null : this.initialMean.clone();
        copy.initialSds = this.initialSds == null ? null : this.initialSds.clone();
        // Matrix(double[][]) wraps the array without copying; copy() gives independent storage.
        copy.inverseCov = this.inverseCov == null ? null : this.inverseCov.copy();
        copy.scaleSd = deepCopy(this.scaleSd);
        copy.sds = this.sds == null ? null : this.sds.clone();
        copy.sumOfGammaWeightedEmissions = this.sumOfGammaWeightedEmissions == null
                ? null : this.sumOfGammaWeightedEmissions.clone();
        return copy;
    }

    /**
     * Decompiler-generated alias of {@link #clone()}, kept for backward compatibility.
     *
     * @return a deep copy of this emission
     * @throws CloneNotSupportedException if cloning fails
     */
    public MultivariateGaussianEmission m147clone() throws CloneNotSupportedException {
        return clone();
    }

    /** Returns the inverse of {@link #getCovarianceMatrix()}. */
    private Matrix getInverseCovarianceMatrix() {
        return getCovarianceMatrix().inverse();
    }

    /**
     * Builds the symmetric covariance matrix from {@link #sds} (diagonal: {@code sd^2}) and
     * the upper triangle of {@link #correlation} (off-diagonal: {@code sd_i * sd_j * rho_ij}).
     */
    private Matrix getCovarianceMatrix() {
        Matrix matrix = new Matrix(this.dim, this.dim, 0.0d);
        for (int i = 0; i < this.dim; i++) {
            matrix.set(i, i, Math.pow(this.sds[i], 2.0d));
        }
        for (int i = 0; i < this.dim; i++) {
            for (int j = i + 1; j < this.dim; j++) {
                double cov = this.sds[i] * this.sds[j] * this.correlation[i][j];
                matrix.set(i, j, cov);
                matrix.set(j, i, cov);
            }
        }
        return matrix;
    }

    /**
     * Derives correlations from a covariance matrix; only the upper triangle
     * ({@code i < j}) of the result is filled, the rest stays {@code 0}.
     *
     * @param matrix the covariance matrix
     * @param sd standard deviations matching the diagonal of {@code matrix}
     * @return a {@code dim x dim} array holding the upper-triangle correlations
     */
    private double[][] getCorrelations(Matrix matrix, double[] sd) {
        double[][] result = new double[this.dim][this.dim];
        for (int i = 0; i < this.dim; i++) {
            for (int j = i + 1; j < this.dim; j++) {
                result[i][j] = (matrix.get(i, j) / sd[i]) / sd[j];
            }
        }
        return result;
    }

    /**
     * Returns a copy of {@code matrix} in which every row is cloned as well, or {@code null}
     * if {@code matrix} is {@code null}.
     */
    private static double[][] deepCopy(double[][] matrix) {
        if (matrix == null) {
            return null;
        }
        double[][] copy = new double[matrix.length][];
        for (int i = 0; i < matrix.length; i++) {
            copy[i] = matrix[i] == null ? null : matrix[i].clone();
        }
        return copy;
    }

    /**
     * Copies mean, standard deviations, correlations and inverse covariance from another
     * emission of the same class and dimension into this one. The arrays are overwritten in
     * place.
     *
     * @param other the emission to copy the parameters from
     * @throws IllegalArgumentException if {@code other} has a different class or dimension
     */
    @Override
    public void setParameters(Emission other) throws IllegalArgumentException {
        if (!other.getClass().equals(getClass()) || ((MultivariateGaussianEmission) other).dim != this.dim) {
            throw new IllegalArgumentException("The transitions are not comparable.");
        }
        MultivariateGaussianEmission source = (MultivariateGaussianEmission) other;
        System.arraycopy(source.mean, 0, this.mean, 0, this.mean.length);
        System.arraycopy(source.sds, 0, this.sds, 0, this.sds.length);
        this.inverseCov = source.inverseCov.copy();
        for (int i = 0; i < this.correlation.length; i++) {
            System.arraycopy(source.correlation[i], 0, this.correlation[i], 0, this.correlation[i].length);
        }
    }
}
