/*
 * Decompiled with CFR 0.152.
 */
package de.jstacs.sequenceScores.statisticalModels.trainable.hmm.states.emissions.continuous;

import Jama.Matrix;
import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.alphabets.Alphabet;
import de.jstacs.data.alphabets.ContinuousAlphabet;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.io.ArrayHandler;
import de.jstacs.io.NonParsableException;
import de.jstacs.io.XMLParser;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.states.emissions.Emission;
import java.text.NumberFormat;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import javax.naming.OperationNotSupportedException;

/**
 * Emission that scores {@code dim}-dimensional continuous observations with a
 * multivariate Gaussian density.
 *
 * <p>The density is parameterised by a mean vector, per-dimension standard deviations and a
 * correlation matrix of which only the strict upper triangle ({@code [i][j]} with
 * {@code i < j}) is read and written by this class. The inverse covariance matrix is cached
 * in {@link #inverseCov} and refreshed whenever the parameters change.
 *
 * <p>Parameter estimation combines the accumulated, weighted sufficient statistics with the
 * hyper-parameters {@code scaleMean}, {@code aprioriMean}, {@code shapeSd} and
 * {@code scaleSd}; see {@link #estimateFromStatistic()} and {@link #getLogPriorTerm()} for
 * the exact formulas.
 *
 * <p>Instances hold mutable scratch state ({@link #emission}) and are therefore not safe for
 * concurrent use.
 */
public class MultivariateGaussianEmission
implements Emission {
    /** Number of dimensions of one observation; equals the length of the mean vector. */
    int dim;
    /** Mean vector restored by {@link #initializeFunctionRandomly()}. */
    double[] initialMean;
    /** Current mean vector. */
    double[] mean;
    /** Standard deviations restored by {@link #initializeFunctionRandomly()}. */
    double[] initialSds;
    /** Current per-dimension standard deviations. */
    double[] sds;
    /** Correlation matrix restored by {@link #initializeFunctionRandomly()}. */
    double[][] initialCorrelation;
    /** Current correlation matrix; only entries {@code [i][j]} with {@code i < j} are used. */
    double[][] correlation;
    /** Cached inverse of the covariance matrix built from {@link #sds} and {@link #correlation}. */
    private Matrix inverseCov;
    /** Hyper-parameter: a-priori mean vector. */
    private double[] aprioriMean;
    /** Hyper-parameter: weight of {@link #aprioriMean} in the mean estimate. */
    private double scaleMean;
    /** Hyper-parameter entering the denominator of the covariance estimate. */
    private double shapeSd;
    /** Hyper-parameter: matrix added to the numerator of the covariance estimate. */
    private double[][] scaleSd;
    /** Accumulated weight per sequence position, keyed by sequence. */
    protected HashMap<Sequence, double[]> gammas;
    /** Sum of all weights passed to {@link #addToStatistic}. */
    private double sumOfGammas;
    /** Per-dimension sum of weight times observed value. */
    private double[] sumOfGammaWeightedEmissions;
    private AlphabetContainer con;
    private static final String TAG = "MultivariateGaussianEmission";
    /** Scratch buffer that receives the observation at one sequence position. */
    double[] emission;

    /**
     * Creates an emission from explicit parameters and hyper-parameters.
     *
     * <p>All arrays are copied deeply, so later changes to the arguments do not affect this
     * instance and the stored initial parameters never alias the current ones.
     *
     * @param mean initial mean vector; its length defines the dimension
     * @param sds initial standard deviations, one per dimension
     * @param correlation initial correlation matrix; only the strict upper triangle is read
     * @param scaleMean weight of {@code aprioriMean} in the mean estimate
     * @param aprioriMean a-priori mean vector
     * @param shapeSd hyper-parameter entering the denominator of the covariance estimate
     * @param scaleSd matrix added to the numerator of the covariance estimate
     */
    public MultivariateGaussianEmission(double[] mean, double[] sds, double[][] correlation, double scaleMean, double[] aprioriMean, double shapeSd, double[][] scaleSd) {
        this.dim = mean.length;
        this.emission = new double[this.dim];
        this.initialMean = mean.clone();
        this.initialSds = sds.clone();
        // Deep copies: a plain clone() of a double[][] would share the row arrays.
        this.initialCorrelation = MultivariateGaussianEmission.deepCopy(correlation);
        this.mean = mean.clone();
        this.sds = sds.clone();
        this.correlation = MultivariateGaussianEmission.deepCopy(correlation);
        this.inverseCov = this.getInverseCovarianceMatrix();
        this.scaleMean = scaleMean;
        this.aprioriMean = aprioriMean.clone();
        this.shapeSd = shapeSd;
        this.scaleSd = MultivariateGaussianEmission.deepCopy(scaleSd);
        this.gammas = new HashMap<Sequence, double[]>();
        // Allocate the statistic up front so that addToStatistic/joinStatistics work even
        // if resetStatistic() has not been called yet.
        this.sumOfGammas = 0.0;
        this.sumOfGammaWeightedEmissions = new double[this.dim];
        this.con = new AlphabetContainer((Alphabet)new ContinuousAlphabet());
    }

    /**
     * Restores an emission from its XML representation as written by {@link #toXML()}.
     * The sufficient statistics start out empty.
     *
     * @param xml the XML representation
     * @throws NonParsableException if the XML cannot be parsed
     */
    public MultivariateGaussianEmission(StringBuffer xml) throws NonParsableException {
        this.fromXML(xml);
        this.gammas = new HashMap<Sequence, double[]>();
        this.resetStatistic();
    }

    /**
     * Adds {@code weight} to the statistic for every position from {@code startPos} to
     * {@code endPos} (both inclusive) of {@code seq}.
     *
     * @param forward must be {@code true}
     * @param startPos first position (inclusive)
     * @param endPos last position (inclusive)
     * @param weight weight added for each position
     * @param seq the observed sequence
     * @throws OperationNotSupportedException if {@code forward} is {@code false}
     */
    @Override
    public void addToStatistic(boolean forward, int startPos, int endPos, double weight, Sequence seq) throws OperationNotSupportedException {
        if (!forward) {
            throw new OperationNotSupportedException();
        }
        double[] weights = this.gammas.get(seq);
        if (weights == null) {
            weights = new double[seq.getLength()];
            this.gammas.put(seq, weights);
        }
        for (int pos = startPos; pos <= endPos; pos++) {
            weights[pos] += weight;
            seq.fillContainer(this.emission, pos);
            this.sumOfGammas += weight;
            for (int d = 0; d < this.dim; d++) {
                this.sumOfGammaWeightedEmissions[d] += weight * this.emission[d];
            }
        }
    }

    /**
     * Pools the statistics of all given emissions into this instance and then copies the
     * pooled statistic back into every other emission. Entries identical to {@code this}
     * are skipped; all others must be {@code MultivariateGaussianEmission} instances.
     *
     * @param emissions the emissions whose statistics are joined
     */
    @Override
    public void joinStatistics(Emission ... emissions) {
        // Pass 1: accumulate every other emission's statistic into this one.
        for (Emission other : emissions) {
            if (other == this) {
                continue;
            }
            MultivariateGaussianEmission c = (MultivariateGaussianEmission)other;
            this.sumOfGammas += c.sumOfGammas;
            for (int d = 0; d < this.dim; d++) {
                this.sumOfGammaWeightedEmissions[d] += c.sumOfGammaWeightedEmissions[d];
            }
            for (Map.Entry<Sequence, double[]> e : c.gammas.entrySet()) {
                double[] own = this.gammas.get(e.getKey());
                if (own == null) {
                    this.gammas.put(e.getKey(), e.getValue().clone());
                    continue;
                }
                double[] theirs = e.getValue();
                for (int j = 0; j < theirs.length; j++) {
                    own[j] += theirs[j];
                }
            }
        }
        // Pass 2: hand the pooled statistic back to every other emission.
        for (Emission other : emissions) {
            if (other == this) {
                continue;
            }
            MultivariateGaussianEmission c = (MultivariateGaussianEmission)other;
            c.sumOfGammas = this.sumOfGammas;
            System.arraycopy(this.sumOfGammaWeightedEmissions, 0, c.sumOfGammaWeightedEmissions, 0, this.dim);
            c.gammas.clear();
            for (Map.Entry<Sequence, double[]> e : this.gammas.entrySet()) {
                c.gammas.put(e.getKey(), e.getValue().clone());
            }
        }
    }

    /**
     * Re-estimates mean, standard deviations, correlations and the cached inverse covariance
     * from the accumulated statistic and the hyper-parameters.
     *
     * <p>The mean is {@code (sumOfGammaWeightedEmissions + aprioriMean * scaleMean) /
     * (sumOfGammas + scaleMean)}. The covariance is
     * {@code (scaleSd + scaleMean * (mean - aprioriMean)^T (mean - aprioriMean)
     * + sum_t gamma_t * (x_t - mean)^T (x_t - mean)) / (sumOfGammas + shapeSd - dim)}.
     *
     * <p>If no sequence has been added, the estimate is based on the hyper-parameters alone.
     * NOTE(review): the JAMA {@code inverse()} call is expected to fail for a singular
     * covariance; confirm against the JAMA documentation.
     */
    @Override
    public void estimateFromStatistic() {
        for (int d = 0; d < this.dim; d++) {
            this.mean[d] = (this.sumOfGammaWeightedEmissions[d] + this.aprioriMean[d] * this.scaleMean) / (this.sumOfGammas + this.scaleMean);
        }
        // Row-vector view of the freshly estimated mean; reused for every position below.
        Matrix meanRow = new Matrix(new double[][]{this.mean});
        Matrix priorDeviation = meanRow.minus(new Matrix(new double[][]{this.aprioriMean}));
        Matrix covNumerator = new Matrix(this.scaleSd).plus(priorDeviation.transpose().times(priorDeviation).times(this.scaleMean));
        double covDenominator = this.sumOfGammas + this.shapeSd - (double)this.dim;
        // for-each instead of do/while: an empty statistic must not throw.
        for (Map.Entry<Sequence, double[]> entry : this.gammas.entrySet()) {
            Sequence seq = entry.getKey();
            double[] weights = entry.getValue();
            for (int t = 0; t < weights.length; t++) {
                seq.fillContainer(this.emission, t);
                Matrix deviation = new Matrix(new double[][]{this.emission}).minus(meanRow);
                covNumerator = covNumerator.plus(deviation.transpose().times(deviation).times(weights[t]));
            }
        }
        Matrix cov = covNumerator.times(1.0 / covDenominator);
        for (int d = 0; d < this.dim; d++) {
            this.sds[d] = Math.sqrt(cov.get(d, d));
        }
        this.correlation = this.getCorrelations(cov, this.sds);
        this.inverseCov = cov.inverse();
    }

    /**
     * Returns the alphabet container of this emission.
     *
     * @return the alphabet container
     */
    @Override
    public AlphabetContainer getAlphabetContainer() {
        return this.con;
    }

    /**
     * Computes the log prior term of the current parameters:
     * {@code (shapeSd - dim) / 2 * log det(inverseCov)
     * - scaleMean / 2 * (mean - aprioriMean) inverseCov (mean - aprioriMean)^T
     * - 1/2 * trace(scaleSd * inverseCov)}.
     *
     * @return the log prior term
     */
    @Override
    public double getLogPriorTerm() {
        double res = 0.0;
        res += (this.shapeSd - (double)this.dim) / 2.0 * Math.log(this.inverseCov.det());
        Matrix deviation = new Matrix(new double[][]{this.mean}).minus(new Matrix(new double[][]{this.aprioriMean}));
        Matrix quadraticForm = deviation.times(this.inverseCov).times(deviation.transpose());
        res -= this.scaleMean / 2.0 * quadraticForm.get(0, 0);
        Matrix scaled = new Matrix(this.scaleSd).times(this.inverseCov);
        res -= 0.5 * scaled.trace();
        return res;
    }

    /**
     * Sums the Gaussian log density of the observations at positions {@code startPos} to
     * {@code endPos} (both inclusive) of {@code seq}.
     *
     * @param forward not used by this implementation
     * @param startPos first position (inclusive)
     * @param endPos last position (inclusive)
     * @param seq the observed sequence
     * @return the summed log density; {@code 0.0} if {@code startPos > endPos}
     * @throws OperationNotSupportedException never thrown by this implementation
     */
    @Override
    public double getLogProbFor(boolean forward, int startPos, int endPos, Sequence seq) throws OperationNotSupportedException {
        // Position-independent terms are computed once; the determinant in particular is
        // too expensive to recompute for every position.
        double logNormalisation = Math.log(Math.sqrt(Math.pow(Math.PI * 2, this.dim)));
        double halfLogDet = 0.5 * Math.log(this.inverseCov.det());
        Matrix meanRow = new Matrix(new double[][]{this.mean});
        double res = 0.0;
        for (int pos = startPos; pos <= endPos; pos++) {
            seq.fillContainer(this.emission, pos);
            // Same accumulation order as before hoisting, so results are unchanged.
            res -= logNormalisation;
            res += halfLogDet;
            Matrix deviation = new Matrix(new double[][]{this.emission}).minus(meanRow);
            Matrix quadraticForm = deviation.times(this.inverseCov).times(deviation.transpose());
            res -= 0.5 * quadraticForm.get(0, 0);
        }
        return res;
    }

    /**
     * Not implemented for this emission.
     *
     * @return always {@code null}
     */
    @Override
    public String getNodeLabel(double weight, String name, NumberFormat nf) {
        return null;
    }

    /**
     * Not implemented for this emission.
     *
     * @return always {@code null}
     */
    @Override
    public String getNodeShape(boolean forward) {
        return null;
    }

    /**
     * Resets mean, standard deviations and correlations to the values given at construction
     * and refreshes the cached inverse covariance. Despite the name, no randomness is
     * involved.
     */
    @Override
    public void initializeFunctionRandomly() {
        this.mean = this.initialMean.clone();
        this.sds = this.initialSds.clone();
        // Deep copy so that later in-place edits cannot reach the stored initial values.
        this.correlation = MultivariateGaussianEmission.deepCopy(this.initialCorrelation);
        this.inverseCov = this.getInverseCovarianceMatrix();
    }

    /**
     * Sets all accumulated statistics to zero. Known sequences stay registered in
     * {@link #gammas} with zeroed weights.
     */
    @Override
    public void resetStatistic() {
        this.sumOfGammas = 0.0;
        if (this.sumOfGammaWeightedEmissions == null) {
            this.sumOfGammaWeightedEmissions = new double[this.dim];
        } else {
            Arrays.fill(this.sumOfGammaWeightedEmissions, 0.0);
        }
        this.resetGammas();
    }

    /** Zeroes the per-position weights of every registered sequence. */
    private void resetGammas() {
        for (double[] weights : this.gammas.values()) {
            Arrays.fill(weights, 0.0);
        }
    }

    /**
     * Serialises the alphabet, the initial and current parameters and the hyper-parameters.
     * Accumulated statistics are not written.
     *
     * @return the XML representation
     */
    @Override
    public StringBuffer toXML() {
        StringBuffer buf = new StringBuffer();
        XMLParser.appendObjectWithTags(buf, this.con, "alphabet");
        XMLParser.appendObjectWithTags(buf, this.initialMean, "initialMean");
        XMLParser.appendObjectWithTags(buf, this.mean, "mean");
        XMLParser.appendObjectWithTags(buf, this.initialSds, "initialSds");
        XMLParser.appendObjectWithTags(buf, this.sds, "sds");
        XMLParser.appendObjectWithTags(buf, this.initialCorrelation, "initialCorrelation");
        XMLParser.appendObjectWithTags(buf, this.correlation, "correlation");
        XMLParser.appendObjectWithTags(buf, this.aprioriMean, "aprioriMean");
        XMLParser.appendObjectWithTags(buf, this.scaleMean, "scaleMean");
        XMLParser.appendObjectWithTags(buf, this.shapeSd, "shapeSd");
        XMLParser.appendObjectWithTags(buf, this.scaleSd, "scaleSd");
        XMLParser.addTags(buf, TAG);
        return buf;
    }

    /**
     * Reads all fields written by {@link #toXML()} and rebuilds the derived state
     * (dimension, scratch buffer, inverse covariance).
     *
     * @param xml the XML representation
     * @throws NonParsableException if the XML cannot be parsed
     */
    protected void fromXML(StringBuffer xml) throws NonParsableException {
        xml = XMLParser.extractForTag(xml, TAG);
        this.con = XMLParser.extractObjectForTags(xml, "alphabet", AlphabetContainer.class);
        this.initialMean = (double[])XMLParser.extractObjectForTags(xml, "initialMean");
        this.mean = (double[])XMLParser.extractObjectForTags(xml, "mean");
        this.initialSds = (double[])XMLParser.extractObjectForTags(xml, "initialSds");
        this.sds = (double[])XMLParser.extractObjectForTags(xml, "sds");
        this.initialCorrelation = (double[][])XMLParser.extractObjectForTags(xml, "initialCorrelation");
        this.correlation = (double[][])XMLParser.extractObjectForTags(xml, "correlation");
        this.aprioriMean = (double[])XMLParser.extractObjectForTags(xml, "aprioriMean");
        this.scaleMean = XMLParser.extractObjectForTags(xml, "scaleMean", Double.TYPE);
        this.shapeSd = XMLParser.extractObjectForTags(xml, "shapeSd", Double.TYPE);
        this.scaleSd = (double[][])XMLParser.extractObjectForTags(xml, "scaleSd");
        this.dim = this.mean.length;
        this.emission = new double[this.dim];
        this.inverseCov = this.getInverseCovarianceMatrix();
    }

    /**
     * Renders means, standard deviations and the upper-triangle correlations as text.
     *
     * @param nf the format applied to every number
     * @return the textual representation
     */
    @Override
    public String toString(NumberFormat nf) {
        StringBuilder res = new StringBuilder("- Means  = ");
        for (int i = 0; i < this.dim; i++) {
            res.append(nf.format(this.mean[i])).append("\t");
        }
        res.append("\n\n");
        for (int i = 0; i < this.dim; i++) {
            res.append("- Standard dev. = ").append(nf.format(this.sds[i])).append("\n");
        }
        res.append("\n\n");
        for (int i = 0; i < this.dim; i++) {
            for (int j = i + 1; j < this.dim; j++) {
                res.append("- Correlation(").append(i + 1).append(",").append(j + 1).append(")  = ");
                res.append(nf.format(this.correlation[i][j])).append("\n");
            }
        }
        return res.toString();
    }

    /**
     * Creates a copy that shares no mutable array, map or matrix with this instance.
     * The alphabet container reference is shared.
     *
     * @return the independent copy
     * @throws CloneNotSupportedException if {@code super.clone()} fails
     */
    @Override
    public MultivariateGaussianEmission clone() throws CloneNotSupportedException {
        MultivariateGaussianEmission clone = (MultivariateGaussianEmission)super.clone();
        // mean is updated in place by estimateFromStatistic/setParameters, so it must not
        // be shared between the original and the copy.
        clone.mean = this.mean == null ? null : this.mean.clone();
        clone.aprioriMean = this.aprioriMean == null ? null : this.aprioriMean.clone();
        clone.sds = this.sds == null ? null : this.sds.clone();
        clone.correlation = MultivariateGaussianEmission.deepCopy(this.correlation);
        clone.initialMean = this.initialMean == null ? null : this.initialMean.clone();
        clone.initialSds = this.initialSds == null ? null : this.initialSds.clone();
        clone.initialCorrelation = MultivariateGaussianEmission.deepCopy(this.initialCorrelation);
        clone.scaleSd = MultivariateGaussianEmission.deepCopy(this.scaleSd);
        clone.emission = this.emission == null ? null : this.emission.clone();
        // copy() duplicates the backing array; wrapping getArray() would share it.
        clone.inverseCov = this.inverseCov == null ? null : this.inverseCov.copy();
        clone.sumOfGammaWeightedEmissions = this.sumOfGammaWeightedEmissions == null ? null : this.sumOfGammaWeightedEmissions.clone();
        if (this.gammas != null) {
            clone.gammas = new HashMap<Sequence, double[]>();
            for (Map.Entry<Sequence, double[]> e : this.gammas.entrySet()) {
                clone.gammas.put(e.getKey(), e.getValue().clone());
            }
        } else {
            clone.gammas = null;
        }
        return clone;
    }

    /** Returns a copy of {@code source} whose rows are copied as well; {@code null} stays {@code null}. */
    private static double[][] deepCopy(double[][] source) {
        if (source == null) {
            return null;
        }
        double[][] copy = new double[source.length][];
        for (int i = 0; i < source.length; i++) {
            copy[i] = source[i] == null ? null : source[i].clone();
        }
        return copy;
    }

    /** Inverts the covariance matrix built from the current parameters. */
    private Matrix getInverseCovarianceMatrix() {
        return this.getCovarianceMatrix().inverse();
    }

    /**
     * Builds the symmetric covariance matrix: variances on the diagonal and
     * {@code sds[i] * sds[j] * correlation[i][j]} (for {@code i < j}) mirrored off-diagonal.
     */
    private Matrix getCovarianceMatrix() {
        Matrix cov = new Matrix(this.dim, this.dim, 0.0);
        for (int i = 0; i < this.dim; i++) {
            cov.set(i, i, Math.pow(this.sds[i], 2.0));
        }
        for (int i = 0; i < this.dim; i++) {
            for (int j = i + 1; j < this.dim; j++) {
                double covVar = this.sds[i] * this.sds[j] * this.correlation[i][j];
                cov.set(i, j, covVar);
                cov.set(j, i, covVar);
            }
        }
        return cov;
    }

    /**
     * Derives correlations from a covariance matrix. Only the strict upper triangle of the
     * result is filled; the diagonal and the lower triangle stay {@code 0.0}.
     */
    private double[][] getCorrelations(Matrix cov, double[] sds) {
        double[][] res = new double[this.dim][this.dim];
        for (int i = 0; i < this.dim; i++) {
            for (int j = i + 1; j < this.dim; j++) {
                res[i][j] = cov.get(i, j) / sds[i] / sds[j];
            }
        }
        return res;
    }

    /**
     * Copies mean, standard deviations, correlations and the inverse covariance from
     * {@code t}. Hyper-parameters, initial parameters and statistics are left untouched.
     *
     * @param t the emission to copy from; must have exactly this class and the same dimension
     * @throws IllegalArgumentException if {@code t} is {@code null}, of a different class or
     *         of a different dimension
     */
    @Override
    public void setParameters(Emission t) throws IllegalArgumentException {
        if (t == null || !t.getClass().equals(this.getClass()) || ((MultivariateGaussianEmission)t).dim != this.dim) {
            throw new IllegalArgumentException("The emissions are not comparable.");
        }
        MultivariateGaussianEmission c = (MultivariateGaussianEmission)t;
        System.arraycopy(c.mean, 0, this.mean, 0, this.mean.length);
        System.arraycopy(c.sds, 0, this.sds, 0, this.sds.length);
        this.inverseCov = c.inverseCov.copy();
        // Replace rather than overwrite rows, so no other matrix can be modified by accident.
        this.correlation = MultivariateGaussianEmission.deepCopy(c.correlation);
    }
}

