/*
 * Decompiled with CFR 0.152.
 */
package de.jstacs.sequenceScores.statisticalModels.trainable.hmm.states.emissions.continuous;

import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.alphabets.Alphabet;
import de.jstacs.data.alphabets.ContinuousAlphabet;
import de.jstacs.data.sequences.MultiDimensionalSequence;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.io.ArrayHandler;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.states.emissions.Emission;
import de.jstacs.utils.random.RandomNumberGenerator;
import de.jtem.numericalMethods.algebra.linear.Determinant;
import de.jtem.numericalMethods.algebra.linear.Inversion;
import java.text.NumberFormat;
import java.util.Arrays;
import javax.naming.OperationNotSupportedException;
import umontreal.iro.lecuyer.randvar.NormalGen;
import umontreal.iro.lecuyer.randvarmulti.MultinormalCholeskyGen;
import umontreal.iro.lecuyer.rng.LFSR258;

/**
 * An {@link Emission} whose density is a multivariate Gaussian, parameterized by a mean
 * vector and a precision (inverse covariance) matrix.
 *
 * <p>For {@code dimension == 1} observations are read with {@link Sequence#continuousVal(int)};
 * for larger dimensions the sequence must be a {@link MultiDimensionalSequence} with exactly
 * {@code dimension} component sequences, whose values at one position form one observation.
 *
 * <p>Parameters are estimated from weighted sufficient statistics (sum of weights, weighted
 * sum of x, weighted sum of x x<sup>T</sup>) combined with the hyperparameters
 * {@code priorMean}, {@code priorPrecMat}, {@code essMu} and {@code essPrec}. The prior looks
 * like a conjugate Normal-Wishart-style prior — TODO confirm the exact parameterization.
 * The mean and/or the precision can be held fixed: a fixed mean equals {@code priorMean},
 * a fixed precision equals {@code priorPrecMat}.
 *
 * <p>Instances mutate internal scratch buffers while evaluating, accumulating and printing,
 * so a single instance must not be used from several threads at once.
 */
public class MultivariateGaussianEmission implements Emission {

    /** Number of components of one observation. */
    private final int dimension;
    /** Alphabet container of the emitted (continuous) values. */
    private final AlphabetContainer con;
    /** Current mean vector. */
    private double[] mean;
    /**
     * Current precision matrix (inverse covariance). While
     * {@link #initializeFunctionRandomly()} runs it temporarily holds a covariance matrix.
     */
    private double[][] precMat;
    /** Cached {@code log(det(precMat))}. */
    private double logDet;
    /** Scratch buffer holding one observation vector. */
    private double[] tempValues;
    /** Scratch matrix; receives matrix inverses and is then swapped with {@link #precMat}. */
    private double[][] precTemp;
    /** Prior mean; also the mean used when {@link #meanFixed} is set. */
    private double[] priorMean;
    /** Prior precision hyperparameter; also the precision used when {@link #precFixed} is set. */
    private double[][] priorPrecMat;
    /** Prior weight (equivalent sample size) of {@link #priorMean} in the mean estimate. */
    private final double essMu;
    /** Prior weight (equivalent sample size) used for the precision estimate. */
    private final double essPrec;
    /** Weighted sum of observations. */
    private double[] meanStat;
    /** Weighted sum of outer products x x^T of observations. */
    private double[][] prcStat;
    /** Sum of the weights of all accumulated observations. */
    private double numStat;
    /** Whether the precision is held fixed at {@link #priorPrecMat}. */
    private final boolean precFixed;
    /** Whether the mean is held fixed at {@link #priorMean}. */
    private final boolean meanFixed;

    /**
     * Creates a new emission. Parameters are all zero until
     * {@link #initializeFunctionRandomly()}, {@link #estimateFromStatistic()} or
     * {@link #setParameters(Emission)} is called.
     *
     * @param dimension number of components of one observation (at least 1)
     * @param priorMean prior mean, length {@code dimension}; copied
     * @param priorPrecMat prior precision hyperparameter, {@code dimension x dimension}; copied
     * @param essMu prior weight of the mean
     * @param essPrec prior weight of the precision
     * @param precFixed if {@code true} the precision is kept at {@code priorPrecMat}
     * @param meanFixed if {@code true} the mean is kept at {@code priorMean}
     * @throws IllegalArgumentException if {@code dimension < 1} or the prior shapes do not match it
     * @throws CloneNotSupportedException if the prior matrix cannot be copied
     */
    public MultivariateGaussianEmission(int dimension, double[] priorMean, double[][] priorPrecMat, double essMu, double essPrec, boolean precFixed, boolean meanFixed) throws CloneNotSupportedException {
        if (dimension < 1) {
            throw new IllegalArgumentException("dimension must be at least 1, got " + dimension);
        }
        if (priorMean.length != dimension) {
            throw new IllegalArgumentException("priorMean must have length " + dimension + ", got " + priorMean.length);
        }
        if (priorPrecMat.length != dimension) {
            throw new IllegalArgumentException("priorPrecMat must have " + dimension + " rows, got " + priorPrecMat.length);
        }
        for (int i = 0; i < dimension; i++) {
            if (priorPrecMat[i].length != dimension) {
                throw new IllegalArgumentException("row " + i + " of priorPrecMat must have length " + dimension + ", got " + priorPrecMat[i].length);
            }
        }
        this.dimension = dimension;
        this.priorMean = priorMean.clone();
        this.priorPrecMat = (double[][]) ArrayHandler.clone((Cloneable[]) priorPrecMat);
        this.essMu = essMu;
        this.essPrec = essPrec;
        this.precFixed = precFixed;
        this.meanFixed = meanFixed;
        this.mean = new double[dimension];
        this.precMat = new double[dimension][dimension];
        this.tempValues = new double[dimension];
        this.precTemp = new double[dimension][dimension];
        this.meanStat = new double[dimension];
        this.prcStat = new double[dimension][dimension];
        this.con = new AlphabetContainer((Alphabet) new ContinuousAlphabet());
    }

    /**
     * Returns a deep copy of this emission; all parameter, prior, statistic and scratch
     * arrays are duplicated so the copy evolves independently.
     */
    @Override
    public MultivariateGaussianEmission clone() throws CloneNotSupportedException {
        MultivariateGaussianEmission clone = (MultivariateGaussianEmission) super.clone();
        clone.mean = this.mean.clone();
        clone.meanStat = this.meanStat.clone();
        clone.prcStat = (double[][]) ArrayHandler.clone((Cloneable[]) this.prcStat);
        clone.precMat = (double[][]) ArrayHandler.clone((Cloneable[]) this.precMat);
        clone.precTemp = (double[][]) ArrayHandler.clone((Cloneable[]) this.precTemp);
        clone.priorMean = this.priorMean.clone();
        clone.priorPrecMat = (double[][]) ArrayHandler.clone((Cloneable[]) this.priorPrecMat);
        clone.tempValues = this.tempValues.clone();
        return clone;
    }

    /** XML serialization is not implemented; always returns {@code null}. */
    @Override
    public StringBuffer toXML() {
        return null;
    }

    @Override
    public AlphabetContainer getAlphabetContainer() {
        return this.con;
    }

    /**
     * Draws random initial parameters.
     *
     * <p>A covariance matrix is either the inverse of {@code priorPrecMat} (fixed precision) or
     * sampled by {@link #sampleCovariance(RandomNumberGenerator)}. The mean is then drawn from
     * a multivariate normal around {@code priorMean} with that covariance, and the covariance
     * is inverted into the precision. Fixed parts are finally reset to their prior values.
     */
    @Override
    public void initializeFunctionRandomly() {
        RandomNumberGenerator rng = new RandomNumberGenerator();
        // Until the inversion below, precMat holds a covariance matrix.
        if (this.precFixed) {
            Inversion.compute(this.priorPrecMat, this.precMat);
        } else {
            sampleCovariance(rng);
        }
        MultinormalCholeskyGen mcg = new MultinormalCholeskyGen(new NormalGen(new LFSR258()), this.priorMean, this.precMat);
        mcg.nextPoint(this.mean);
        // Covariance -> precision, reusing the scratch matrix instead of allocating.
        Inversion.compute(this.precMat, this.precTemp);
        double[][] swap = this.precMat;
        this.precMat = this.precTemp;
        this.precTemp = swap;
        if (this.meanFixed) {
            this.mean = this.priorMean.clone();
        }
        if (this.precFixed) {
            this.precMat = copyOfPriorPrecMat();
        }
        this.logDet = Math.log(Determinant.compute(this.precMat));
    }

    /**
     * Samples a random covariance matrix into {@link #precMat}.
     *
     * <p>Variances are {@code 1 / rng.nextGamma(essPrec, essPrec * priorPrecMat[i][i])}.
     * Covariances are uniform in {@code +-sqrt(v_i * v_j) / (d - 1)}: the implied correlation
     * matrix is then (almost surely) strictly diagonally dominant, hence positive definite,
     * which the Cholesky-based normal generator requires.
     */
    private void sampleCovariance(RandomNumberGenerator rng) {
        for (int i = 0; i < this.dimension; i++) {
            this.precMat[i][i] = 1.0 / rng.nextGamma(this.essPrec, this.essPrec * this.priorPrecMat[i][i]);
        }
        double scale = 1.0 / Math.max(1, this.dimension - 1);
        for (int i = 1; i < this.dimension; i++) {
            for (int j = 0; j < i; j++) {
                double bound = scale * Math.sqrt(this.precMat[i][i] * this.precMat[j][j]);
                this.precMat[i][j] = rng.nextUniform(-bound, bound);
                this.precMat[j][i] = this.precMat[i][j];
            }
        }
    }

    /**
     * Returns a deep copy of {@link #priorPrecMat}.
     *
     * @throws RuntimeException wrapping the {@link CloneNotSupportedException} if copying fails
     */
    private double[][] copyOfPriorPrecMat() {
        try {
            return (double[][]) ArrayHandler.clone((Cloneable[]) this.priorPrecMat);
        } catch (CloneNotSupportedException e) {
            throw new RuntimeException("could not copy the prior precision matrix", e);
        }
    }

    /**
     * Log density of one observation:
     * {@code 0.5 * (log|P| - d*log(2*pi) - (x - mean)^T P (x - mean))} with precision {@code P}.
     */
    private double getLogProbFor(double[] values) {
        double quad = 0.0;
        for (int i = 0; i < values.length; i++) {
            double di = values[i] - this.mean[i];
            for (int j = 0; j < values.length; j++) {
                quad += di * this.precMat[i][j] * (values[j] - this.mean[j]);
            }
        }
        return 0.5 * (this.logDet - (double) this.dimension * Math.log(Math.PI * 2) - quad);
    }

    /**
     * Returns the summed log density of the observations at positions
     * {@code startPos..endPos} (inclusive). {@code forward} is not used.
     *
     * @throws OperationNotSupportedException if {@code dimension > 1} and {@code seq} is not a
     *         {@link MultiDimensionalSequence} with {@code dimension} component sequences
     */
    @Override
    public double getLogProbFor(boolean forward, int startPos, int endPos, Sequence seq) throws OperationNotSupportedException {
        double val = 0.0;
        if (this.dimension == 1) {
            for (int pos = startPos; pos <= endPos; pos++) {
                this.tempValues[0] = seq.continuousVal(pos);
                val += getLogProbFor(this.tempValues);
            }
        } else {
            MultiDimensionalSequence multi = asMatchingMultiDimensional(seq);
            for (int pos = startPos; pos <= endPos; pos++) {
                multi.fillContainer(this.tempValues, pos);
                val += getLogProbFor(this.tempValues);
            }
        }
        return val;
    }

    /**
     * Casts {@code seq} to a {@link MultiDimensionalSequence} after checking that it is one and
     * that its number of component sequences equals {@link #dimension}.
     */
    private MultiDimensionalSequence asMatchingMultiDimensional(Sequence seq) throws OperationNotSupportedException {
        if (!(seq instanceof MultiDimensionalSequence)) {
            throw new OperationNotSupportedException("a MultiDimensionalSequence is required for dimension " + this.dimension);
        }
        MultiDimensionalSequence multi = (MultiDimensionalSequence) seq;
        if (multi.getNumberOfSequences() != this.dimension) {
            throw new OperationNotSupportedException("expected " + this.dimension + " component sequences, got " + multi.getNumberOfSequences());
        }
        return multi;
    }

    /**
     * Returns the log prior of the current parameters. Only parameter-dependent terms are
     * computed; this appears to be the log prior up to an additive constant (TODO confirm).
     * Terms for fixed parts (mean or precision) are skipped.
     */
    @Override
    public double getLogPriorTerm() {
        double lp = 0.0;
        if (!this.meanFixed) {
            for (int i = 0; i < this.dimension; i++) {
                for (int j = 0; j < this.dimension; j++) {
                    lp += (this.priorMean[i] - this.mean[i]) * this.precMat[i][j] * (this.priorMean[j] - this.mean[j]);
                }
            }
            lp *= -this.essMu / 2.0;
        }
        if (!this.precFixed) {
            lp += (this.essPrec - (double) this.dimension - 1.0) / 2.0 * this.logDet;
            // subtract 0.5 * trace(priorPrecMat * precMat)
            for (int i = 0; i < this.dimension; i++) {
                for (int k = 0; k < this.dimension; k++) {
                    lp -= 0.5 * this.priorPrecMat[i][k] * this.precMat[k][i];
                }
            }
        }
        if (!this.meanFixed) {
            lp += 0.5 * this.logDet;
        }
        return lp;
    }

    /** Clears all accumulated sufficient statistics. */
    @Override
    public void resetStatistic() {
        Arrays.fill(this.meanStat, 0.0);
        for (double[] row : this.prcStat) {
            Arrays.fill(row, 0.0);
        }
        this.numStat = 0.0;
    }

    /**
     * Adds the observations at positions {@code startPos..endPos} (inclusive), each with
     * {@code weight}, to the sufficient statistics. {@code forward} is not used.
     *
     * @throws OperationNotSupportedException if {@code dimension > 1} and {@code seq} is not a
     *         {@link MultiDimensionalSequence} with {@code dimension} component sequences
     */
    @Override
    public void addToStatistic(boolean forward, int startPos, int endPos, double weight, Sequence seq) throws OperationNotSupportedException {
        if (this.dimension == 1) {
            for (int pos = startPos; pos <= endPos; pos++) {
                double x = seq.continuousVal(pos);
                this.meanStat[0] += x * weight;
                this.prcStat[0][0] += x * x * weight;
                this.numStat += weight;
            }
        } else {
            MultiDimensionalSequence multi = asMatchingMultiDimensional(seq);
            for (int pos = startPos; pos <= endPos; pos++) {
                multi.fillContainer(this.tempValues, pos);
                for (int k = 0; k < this.tempValues.length; k++) {
                    this.meanStat[k] += this.tempValues[k] * weight;
                    for (int m = 0; m < this.prcStat[k].length; m++) {
                        this.prcStat[k][m] += this.tempValues[k] * this.tempValues[m] * weight;
                    }
                }
                this.numStat += weight;
            }
        }
    }

    /**
     * Pools the statistics of this emission and all given emissions (which must be
     * {@code MultivariateGaussianEmission}s of the same dimension) and copies the pooled
     * statistics into every given emission other than this one.
     */
    @Override
    public void joinStatistics(Emission... emissions) {
        // Pass 1: accumulate every other emission into this one.
        for (Emission e : emissions) {
            if (e == this) {
                continue;
            }
            MultivariateGaussianEmission other = (MultivariateGaussianEmission) e;
            for (int j = 0; j < this.meanStat.length; j++) {
                this.meanStat[j] += other.meanStat[j];
                for (int k = 0; k < this.prcStat[j].length; k++) {
                    this.prcStat[j][k] += other.prcStat[j][k];
                }
            }
            this.numStat += other.numStat;
        }
        // Pass 2: broadcast the pooled statistics back.
        for (Emission e : emissions) {
            if (e == this) {
                continue;
            }
            MultivariateGaussianEmission other = (MultivariateGaussianEmission) e;
            System.arraycopy(this.meanStat, 0, other.meanStat, 0, this.meanStat.length);
            for (int j = 0; j < this.meanStat.length; j++) {
                System.arraycopy(this.prcStat[j], 0, other.prcStat[j], 0, this.prcStat[j].length);
            }
            other.numStat = this.numStat;
        }
    }

    /**
     * Re-estimates mean and precision from the accumulated statistics and the prior.
     *
     * <p>Mean: {@code (essMu * priorMean + sum w x) / (sum w + essMu)}. Covariance:
     * {@code (priorPrecMat + S + essMu * (mean - priorMean)(mean - priorMean)^T)
     * / (essPrec - d + sum w)}, where {@code S} is the weighted scatter around the new mean;
     * the precision is its inverse. Fixed parts are reset to their prior values.
     */
    @Override
    public void estimateFromStatistic() {
        if (this.meanFixed) {
            this.mean = this.priorMean.clone();
        } else {
            double denom = this.numStat + this.essMu;
            for (int i = 0; i < this.dimension; i++) {
                this.mean[i] = (this.essMu * this.priorMean[i] + this.meanStat[i]) / denom;
            }
        }
        if (this.precFixed) {
            this.precMat = copyOfPriorPrecMat();
        } else {
            // precTemp := deviation of the mean from the prior mean as an outer product.
            for (int i = 0; i < this.dimension; i++) {
                for (int j = 0; j < this.dimension; j++) {
                    this.precTemp[i][j] = this.meanFixed ? 0.0 : (this.mean[i] - this.priorMean[i]) * (this.mean[j] - this.priorMean[j]);
                }
            }
            double denom = this.essPrec - (double) this.dimension + this.numStat;
            for (int i = 0; i < this.dimension; i++) {
                for (int j = 0; j < this.dimension; j++) {
                    // weighted scatter sum w (x - mean)(x - mean)^T expanded via the statistics
                    double scatter = this.prcStat[i][j] - this.mean[i] * this.meanStat[j] - this.mean[j] * this.meanStat[i] + this.numStat * this.mean[i] * this.mean[j];
                    this.precMat[i][j] = (this.priorPrecMat[i][j] + scatter + this.essMu * this.precTemp[i][j]) / denom;
                }
            }
            // Covariance -> precision, reusing the scratch matrix.
            Inversion.compute(this.precMat, this.precTemp);
            double[][] swap = this.precMat;
            this.precMat = this.precTemp;
            this.precTemp = swap;
        }
        this.logDet = Math.log(Determinant.compute(this.precMat));
    }

    /** No custom node shape; always returns {@code null}. */
    @Override
    public String getNodeShape(boolean forward) {
        return null;
    }

    /** Always labels the node {@code "G"}. */
    @Override
    public String getNodeLabel(double weight, String name, NumberFormat nf) {
        return "G";
    }

    /**
     * Copies mean and precision from {@code t} and recomputes the cached log-determinant.
     *
     * @throws IllegalArgumentException if {@code t} is not a {@code MultivariateGaussianEmission}
     *         of the same dimension
     */
    @Override
    public void setParameters(Emission t) throws IllegalArgumentException {
        if (!(t instanceof MultivariateGaussianEmission)) {
            throw new IllegalArgumentException("expected a MultivariateGaussianEmission, got " + (t == null ? "null" : t.getClass().getName()));
        }
        MultivariateGaussianEmission other = (MultivariateGaussianEmission) t;
        if (other.dimension != this.dimension) {
            throw new IllegalArgumentException("dimension mismatch: " + this.dimension + " vs " + other.dimension);
        }
        for (int i = 0; i < this.dimension; i++) {
            this.mean[i] = other.mean[i];
            System.arraycopy(other.precMat[i], 0, this.precMat[i], 0, this.precMat[i].length);
        }
        this.logDet = Math.log(Determinant.compute(this.precMat));
    }

    /**
     * Returns the mean, a blank line, and then the covariance matrix (inverse of the
     * precision), one row per line. Overwrites the scratch matrix {@link #precTemp}.
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder(Arrays.toString(this.mean)).append("\n\n");
        Inversion.compute(this.precMat, this.precTemp);
        for (int i = 0; i < this.precTemp.length; i++) {
            if (i > 0) {
                sb.append('\n');
            }
            sb.append(Arrays.toString(this.precTemp[i]));
        }
        return sb.toString();
    }
}

