/*
 * Decompiled with CFR 0.152.
 */
package de.jstacs.sequenceScores.statisticalModels.trainable.hmm.states.emissions.continuous;

import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.alphabets.Alphabet;
import de.jstacs.data.alphabets.ContinuousAlphabet;
import de.jstacs.data.sequences.MultiDimensionalSequence;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.io.ArrayHandler;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.states.emissions.Emission;
import de.jstacs.utils.random.RandomNumberGenerator;
import de.jtem.numericalMethods.algebra.group.Permutation;
import de.jtem.numericalMethods.algebra.linear.Determinant;
import de.jtem.numericalMethods.algebra.linear.Inversion;
import java.text.NumberFormat;
import java.util.Arrays;
import javax.naming.OperationNotSupportedException;
import umontreal.iro.lecuyer.randvar.NormalGen;
import umontreal.iro.lecuyer.randvarmulti.MultinormalCholeskyGen;
import umontreal.iro.lecuyer.rng.LFSR258;

/**
 * Emission for continuous (optionally multi-dimensional) observations, modelled by a
 * multivariate Gaussian whose correlation matrix is fixed at construction time.
 *
 * <p>The covariance matrix is always assembled as
 * {@code cov[i][j] = corrMat[i][j] * sds[i] * sds[j]} and then inverted into a precision
 * matrix. Only the per-dimension standard deviations {@code sds} and, unless
 * {@code meanFixed} is set, the mean vector are learned from weighted sufficient statistics.
 *
 * <p>{@link #getLogProbFor(boolean, int, int, Sequence)} and
 * {@link #addToStatistic(boolean, int, int, double, Sequence)} write into a shared scratch
 * buffer, so one instance should not be used from several threads at once.
 */
public class FixedCorrelationGaussianEmission
implements Emission {
    /** Convergence threshold (sum of squared changes) of the standard-deviation fixed-point iteration. */
    private static final double FIXED_POINT_TOLERANCE = 1.0E-6;
    /** Upper bound on fixed-point sweeps, so a non-converging update cannot hang estimation. */
    private static final int MAX_FIXED_POINT_ITERATIONS = 10000;

    /** Number of dimensions of one observation. */
    private int dimension;
    /** Alphabet container over a single continuous alphabet. */
    private AlphabetContainer con;
    /** Current mean vector. */
    private double[] mean;
    /**
     * Holds the covariance matrix while it is being assembled and, after
     * {@link #invertCovarianceIntoPrecision()}, its inverse; density and prior computations read
     * it as the precision matrix.
     */
    private double[][] prePrec;
    /** Current per-dimension standard deviations. */
    private double[] sds;
    /** Natural logarithm of the determinant of {@link #prePrec} (the precision matrix). */
    private double logDet;
    /** Scratch buffer receiving one observation at a time. */
    private double[] tempValues;
    /** Prior mean vector; also the mean used when {@link #meanFixed} is set. */
    private double[] priorMean;
    /** Fixed correlation matrix. */
    private double[][] corrMat;
    /** Inverse of {@link #corrMat}, computed once in the constructor. */
    private double[][] invCorrMat;
    /** Scratch matrix: mean-deviation outer product during estimation, then the inversion target. */
    private double[][] precTemp;
    /** Scratch matrix holding the unconstrained covariance estimate used to refine {@link #sds}. */
    private double[][] fullPrecMat;
    /** Per-dimension prior hyper-parameters for the variances (gamma draw, prior term, variance update). */
    private double[] priorPrec;
    /** Equivalent sample size (pseudo-count weight) of the prior on the mean. */
    private double essMu;
    /** Equivalent sample size of the prior on the variances. */
    private double essPrec;
    /** Weighted sum of the observations. */
    private double[] meanStat;
    /** Weighted sum of the outer products of the observations. */
    private double[][] prcStat;
    /** Total weight of the observations added to the statistics. */
    private double numStat;
    /** Whether the mean is fixed to {@link #priorMean} instead of being estimated. */
    private boolean meanFixed;
    /**
     * Weight of the previous standard deviations when damping a new estimate. It is doubled after
     * every {@link #estimateFromStatistic()} call and reset to 1 in {@link #clone()}.
     */
    private double decay;

    /**
     * Creates a new emission with a fixed correlation structure.
     *
     * @param dimension the number of dimensions of one observation (at least 1)
     * @param priorMean the prior mean vector of length {@code dimension}; copied
     * @param priorPrec the per-dimension prior hyper-parameters for the variances, of length
     *        {@code dimension}; copied
     * @param essMu the equivalent sample size of the prior on the mean
     * @param essPrec the equivalent sample size of the prior on the variances
     * @param corrMat the fixed {@code dimension x dimension} correlation matrix; copied
     * @param meanFixed if {@code true}, the mean is fixed to {@code priorMean} and never estimated
     * @throws IllegalArgumentException if {@code dimension < 1} or an array does not match
     *         {@code dimension}
     * @throws CloneNotSupportedException kept in the signature for backward compatibility; not
     *         thrown
     */
    public FixedCorrelationGaussianEmission(int dimension, double[] priorMean, double[] priorPrec, double essMu, double essPrec, double[][] corrMat, boolean meanFixed) throws CloneNotSupportedException {
        if (dimension < 1) {
            throw new IllegalArgumentException("dimension must be at least 1, but was " + dimension);
        }
        FixedCorrelationGaussianEmission.checkLength(priorMean, dimension, "priorMean");
        FixedCorrelationGaussianEmission.checkLength(priorPrec, dimension, "priorPrec");
        if (corrMat.length != dimension) {
            throw new IllegalArgumentException("corrMat must have " + dimension + " rows, but has " + corrMat.length);
        }
        for (int i = 0; i < dimension; i++) {
            FixedCorrelationGaussianEmission.checkLength(corrMat[i], dimension, "corrMat[" + i + "]");
        }
        this.dimension = dimension;
        this.priorMean = priorMean.clone();
        this.priorPrec = priorPrec.clone();
        this.essMu = essMu;
        this.essPrec = essPrec;
        this.corrMat = FixedCorrelationGaussianEmission.copyMatrix(corrMat);
        this.invCorrMat = new double[dimension][dimension];
        Inversion.compute(corrMat, this.invCorrMat);
        this.meanFixed = meanFixed;
        this.mean = new double[dimension];
        this.prePrec = new double[dimension][dimension];
        this.fullPrecMat = new double[dimension][dimension];
        this.precTemp = new double[dimension][dimension];
        this.sds = new double[dimension];
        this.tempValues = new double[dimension];
        this.meanStat = new double[dimension];
        this.prcStat = new double[dimension][dimension];
        this.con = new AlphabetContainer((Alphabet)new ContinuousAlphabet());
        this.decay = 1.0;
    }

    /** Throws an {@link IllegalArgumentException} unless {@code values} has exactly {@code expected} entries. */
    private static void checkLength(double[] values, int expected, String name) {
        if (values.length != expected) {
            throw new IllegalArgumentException(name + " must have length " + expected + ", but has length " + values.length);
        }
    }

    /** Returns a deep copy of {@code matrix}; every row is cloned. */
    private static double[][] copyMatrix(double[][] matrix) {
        double[][] copy = new double[matrix.length][];
        for (int i = 0; i < matrix.length; i++) {
            copy[i] = matrix[i].clone();
        }
        return copy;
    }

    /**
     * Returns a deep copy of this emission.
     *
     * <p>All parameter, statistic and scratch arrays are copied. The alphabet container is shared,
     * and the damping weight {@link #decay} of the copy is reset to 1.
     *
     * @return the copy
     * @throws CloneNotSupportedException if {@link Object#clone()} fails
     */
    @Override
    public FixedCorrelationGaussianEmission clone() throws CloneNotSupportedException {
        FixedCorrelationGaussianEmission clone = (FixedCorrelationGaussianEmission)super.clone();
        clone.mean = this.mean.clone();
        clone.prePrec = FixedCorrelationGaussianEmission.copyMatrix(this.prePrec);
        clone.meanStat = this.meanStat.clone();
        clone.prcStat = FixedCorrelationGaussianEmission.copyMatrix(this.prcStat);
        clone.precTemp = FixedCorrelationGaussianEmission.copyMatrix(this.precTemp);
        clone.fullPrecMat = FixedCorrelationGaussianEmission.copyMatrix(this.fullPrecMat);
        clone.sds = this.sds.clone();
        clone.corrMat = FixedCorrelationGaussianEmission.copyMatrix(this.corrMat);
        clone.invCorrMat = FixedCorrelationGaussianEmission.copyMatrix(this.invCorrMat);
        clone.priorMean = this.priorMean.clone();
        clone.priorPrec = this.priorPrec.clone();
        clone.tempValues = this.tempValues.clone();
        clone.decay = 1.0;
        return clone;
    }

    /**
     * XML serialization is not implemented for this emission.
     *
     * @return always {@code null}
     */
    @Override
    public StringBuffer toXML() {
        return null;
    }

    /**
     * Returns the alphabet container (a single continuous alphabet).
     *
     * @return the alphabet container
     */
    @Override
    public AlphabetContainer getAlphabetContainer() {
        return this.con;
    }

    /**
     * Draws random initial parameters.
     *
     * <p>For every dimension {@code i}, a value is drawn with
     * {@code rng.nextGamma(essPrec, 1 / priorPrec[i])}. Its reciprocal becomes the variance and the
     * square root of that variance becomes the standard deviation. The covariance is then assembled
     * from {@link #corrMat} and these standard deviations. Unless the mean is fixed, the mean is
     * drawn from a multivariate normal with mean {@link #priorMean} and that covariance. Finally,
     * the covariance is inverted into the precision matrix.
     */
    @Override
    public void initializeFunctionRandomly() {
        RandomNumberGenerator rng = new RandomNumberGenerator();
        for (int i = 0; i < this.dimension; i++) {
            double variance = 1.0 / rng.nextGamma(this.essPrec, 1.0 / this.priorPrec[i]);
            this.prePrec[i][i] = variance;
            this.sds[i] = Math.sqrt(variance);
        }
        for (int i = 0; i < this.dimension; i++) {
            for (int j = 0; j < i; j++) {
                this.prePrec[i][j] = this.corrMat[i][j] * this.sds[i] * this.sds[j];
                this.prePrec[j][i] = this.prePrec[i][j];
            }
        }
        if (this.meanFixed) {
            System.arraycopy(this.priorMean, 0, this.mean, 0, this.dimension);
        } else {
            // prePrec still holds the covariance here, so the mean must be drawn before inverting it.
            MultinormalCholeskyGen mcg = new MultinormalCholeskyGen(new NormalGen(new LFSR258()), this.priorMean, this.prePrec);
            mcg.nextPoint(this.mean);
        }
        this.invertCovarianceIntoPrecision();
    }

    /**
     * Inverts the covariance currently held in {@link #prePrec} and stores the inverse in
     * {@link #prePrec}. The two matrix buffers are swapped instead of copied. The method then
     * updates {@link #logDet}.
     */
    private void invertCovarianceIntoPrecision() {
        Inversion.compute(this.prePrec, this.precTemp);
        double[][] covariance = this.prePrec;
        this.prePrec = this.precTemp;
        this.precTemp = covariance;
        this.logDet = Math.log(Determinant.compute(this.prePrec));
    }

    /**
     * Returns the multivariate normal log density of one observation.
     *
     * <p>It is computed as
     * {@code 0.5 * (logDet - dimension * log(2 pi) - (x - mean)^T prePrec (x - mean))}.
     *
     * @param values the observation, of length {@link #dimension}
     * @return the log density
     */
    private double getLogDensity(double[] values) {
        double val = 0.0;
        for (int i = 0; i < values.length; i++) {
            for (int j = 0; j < values.length; j++) {
                val -= (values[i] - this.mean[i]) * this.prePrec[i][j] * (values[j] - this.mean[j]);
            }
        }
        val += this.logDet - (double)this.dimension * Math.log(Math.PI * 2);
        return val * 0.5;
    }

    /**
     * Casts {@code seq} to a {@link MultiDimensionalSequence} with exactly {@link #dimension}
     * component sequences.
     *
     * @param seq the sequence
     * @return {@code seq} as a multi-dimensional sequence
     * @throws OperationNotSupportedException if {@code seq} is not multi-dimensional or has the
     *         wrong number of component sequences
     */
    private MultiDimensionalSequence asMultiDimensional(Sequence seq) throws OperationNotSupportedException {
        if (!(seq instanceof MultiDimensionalSequence)) {
            throw new OperationNotSupportedException("a MultiDimensionalSequence is required for a " + this.dimension + "-dimensional emission");
        }
        MultiDimensionalSequence mds = (MultiDimensionalSequence)seq;
        if (mds.getNumberOfSequences() != this.dimension) {
            throw new OperationNotSupportedException("expected " + this.dimension + " sequences, but got " + mds.getNumberOfSequences());
        }
        return mds;
    }

    /**
     * Returns the summed log density of the observations at positions {@code startPos} to
     * {@code endPos} (both inclusive).
     *
     * <p>For one dimension, {@link Sequence#continuousVal(int)} is used. Otherwise, {@code seq}
     * must be a {@link MultiDimensionalSequence} with {@link #dimension} component sequences.
     *
     * @param forward unused
     * @param startPos first position (inclusive)
     * @param endPos last position (inclusive)
     * @param seq the sequence
     * @return the summed log density
     * @throws OperationNotSupportedException if a multi-dimensional emission is given a sequence
     *         that is not multi-dimensional or has the wrong number of component sequences
     */
    @Override
    public double getLogProbFor(boolean forward, int startPos, int endPos, Sequence seq) throws OperationNotSupportedException {
        double val = 0.0;
        if (this.dimension == 1) {
            for (int pos = startPos; pos <= endPos; pos++) {
                this.tempValues[0] = seq.continuousVal(pos);
                val += this.getLogDensity(this.tempValues);
            }
        } else {
            MultiDimensionalSequence mds = this.asMultiDimensional(seq);
            for (int pos = startPos; pos <= endPos; pos++) {
                mds.fillContainer(this.tempValues, pos);
                val += this.getLogDensity(this.tempValues);
            }
        }
        return val;
    }

    /**
     * Returns the log prior term of the current parameters; no normalizing constants are added.
     *
     * <p>Unless the mean is fixed, the term contains
     * {@code -(priorMean - mean)^T prePrec (priorMean - mean) / (2 essMu) + 0.5 * logDet}. For every
     * dimension it adds {@code -(2 essPrec + 1) log(sds[i]) - priorPrec[i] / sds[i]^2}.
     *
     * @return the log prior term
     */
    @Override
    public double getLogPriorTerm() {
        double lp = 0.0;
        if (!this.meanFixed) {
            for (int i = 0; i < this.dimension; i++) {
                for (int j = 0; j < this.dimension; j++) {
                    lp += (this.priorMean[i] - this.mean[i]) * this.prePrec[i][j] * (this.priorMean[j] - this.mean[j]);
                }
            }
            lp *= -1.0 / (2.0 * this.essMu);
        }
        for (int i = 0; i < this.dimension; i++) {
            lp -= (2.0 * this.essPrec + 1.0) * Math.log(this.sds[i]) + this.priorPrec[i] / (this.sds[i] * this.sds[i]);
        }
        if (!this.meanFixed) {
            lp += 0.5 * this.logDet;
        }
        return lp;
    }

    /** Clears all accumulated sufficient statistics. */
    @Override
    public void resetStatistic() {
        Arrays.fill(this.meanStat, 0.0);
        for (double[] row : this.prcStat) {
            Arrays.fill(row, 0.0);
        }
        this.numStat = 0.0;
    }

    /**
     * Adds the observations at positions {@code startPos} to {@code endPos} (both inclusive) to the
     * sufficient statistics, each with weight {@code weight}.
     *
     * @param forward unused
     * @param startPos first position (inclusive)
     * @param endPos last position (inclusive)
     * @param weight weight of every observation
     * @param seq the sequence
     * @throws OperationNotSupportedException if a multi-dimensional emission is given a sequence
     *         that is not multi-dimensional or has the wrong number of component sequences
     */
    @Override
    public void addToStatistic(boolean forward, int startPos, int endPos, double weight, Sequence seq) throws OperationNotSupportedException {
        if (this.dimension == 1) {
            for (int pos = startPos; pos <= endPos; pos++) {
                double x = seq.continuousVal(pos);
                this.meanStat[0] += x * weight;
                this.prcStat[0][0] += x * x * weight;
                this.numStat += weight;
            }
        } else {
            MultiDimensionalSequence mds = this.asMultiDimensional(seq);
            for (int pos = startPos; pos <= endPos; pos++) {
                mds.fillContainer(this.tempValues, pos);
                for (int k = 0; k < this.dimension; k++) {
                    this.meanStat[k] += this.tempValues[k] * weight;
                    for (int m = 0; m < this.dimension; m++) {
                        this.prcStat[k][m] += this.tempValues[k] * this.tempValues[m] * weight;
                    }
                }
                this.numStat += weight;
            }
        }
    }

    /**
     * Pools the sufficient statistics of all given emissions.
     *
     * <p>The statistics of every emission other than {@code this} are first added to this
     * emission's statistics. The pooled result is then copied back into each of those emissions.
     * All entries must be {@code FixedCorrelationGaussianEmission}s of the same dimension.
     *
     * @param emissions the emissions to join; may or may not contain {@code this}
     */
    @Override
    public void joinStatistics(Emission ... emissions) {
        for (Emission e : emissions) {
            if (e != this) {
                FixedCorrelationGaussianEmission other = (FixedCorrelationGaussianEmission)e;
                for (int j = 0; j < this.dimension; j++) {
                    this.meanStat[j] += other.meanStat[j];
                    for (int k = 0; k < this.dimension; k++) {
                        this.prcStat[j][k] += other.prcStat[j][k];
                    }
                }
                this.numStat += other.numStat;
            }
        }
        for (Emission e : emissions) {
            if (e != this) {
                FixedCorrelationGaussianEmission other = (FixedCorrelationGaussianEmission)e;
                System.arraycopy(this.meanStat, 0, other.meanStat, 0, this.dimension);
                for (int j = 0; j < this.dimension; j++) {
                    System.arraycopy(this.prcStat[j], 0, other.prcStat[j], 0, this.dimension);
                }
                other.numStat = this.numStat;
            }
        }
    }

    /**
     * Re-estimates the parameters from the accumulated sufficient statistics.
     *
     * <p>The steps are:
     * <ol>
     * <li>Estimate the mean, unless it is fixed.</li>
     * <li>Estimate marginal standard deviations.</li>
     * <li>Compute an unconstrained covariance estimate.</li>
     * <li>Refine the standard deviations so that they fit the fixed correlation matrix.</li>
     * <li>Damp the refined values towards the marginal estimates with weight {@link #decay}, then
     * double {@link #decay}.</li>
     * <li>Rebuild the covariance from {@link #corrMat} and the standard deviations, and invert it
     * into the precision matrix.</li>
     * </ol>
     */
    @Override
    public void estimateFromStatistic() {
        this.estimateMean();
        this.computeMeanDeviation();
        this.estimateMarginalSds();
        this.estimateUnconstrainedCovariance();
        double[] sdBack = this.sds.clone();
        this.refineSdsForFixedCorrelation();
        // Damp towards the marginal estimates; decay grows each call, so later updates move less.
        for (int i = 0; i < this.dimension; i++) {
            this.sds[i] = (this.sds[i] + this.decay * sdBack[i]) / (this.decay + 1.0);
        }
        this.decay *= 2.0;
        for (int i = 0; i < this.dimension; i++) {
            for (int j = 0; j <= i; j++) {
                this.prePrec[i][j] = this.corrMat[i][j] * this.sds[i] * this.sds[j];
                this.prePrec[j][i] = this.prePrec[i][j];
            }
        }
        this.invertCovarianceIntoPrecision();
    }

    /** Sets the mean to the prior mean (if fixed) or to the prior-weighted posterior mean. */
    private void estimateMean() {
        if (this.meanFixed) {
            System.arraycopy(this.priorMean, 0, this.mean, 0, this.dimension);
        } else {
            for (int i = 0; i < this.dimension; i++) {
                this.mean[i] = (this.essMu * this.priorMean[i] + this.meanStat[i]) / (this.numStat + this.essMu);
            }
        }
    }

    /**
     * Stores the outer product {@code (mean - priorMean)(mean - priorMean)^T} in
     * {@link #precTemp}, or zeros if the mean is fixed.
     */
    private void computeMeanDeviation() {
        for (int i = 0; i < this.dimension; i++) {
            for (int j = 0; j < this.dimension; j++) {
                this.precTemp[i][j] = this.meanFixed ? 0.0 : (this.mean[i] - this.priorMean[i]) * (this.mean[j] - this.priorMean[j]);
            }
        }
    }

    /** Estimates each standard deviation from its own dimension's scatter and the prior. */
    private void estimateMarginalSds() {
        double extraDof = this.meanFixed ? 0.0 : 1.0;
        for (int i = 0; i < this.dimension; i++) {
            double scatter = this.prcStat[i][i] - 2.0 * this.mean[i] * this.meanStat[i] + this.numStat * this.mean[i] * this.mean[i];
            double variance = (this.invCorrMat[i][i] * (scatter + this.essMu * this.precTemp[i][i]) + 2.0 * this.priorPrec[i]) / (this.numStat + 2.0 * this.essPrec + 1.0 + extraDof);
            this.sds[i] = Math.sqrt(variance);
        }
    }

    /**
     * Fills {@link #fullPrecMat} with an unconstrained covariance estimate. The off-diagonal
     * entries come from the cross scatter. The diagonal entries are the squared marginal
     * standard deviations.
     */
    private void estimateUnconstrainedCovariance() {
        double denominator = this.essPrec - (double)this.dimension + this.numStat;
        for (int i = 0; i < this.dimension; i++) {
            for (int j = 0; j < this.dimension; j++) {
                if (i != j) {
                    double scatter = this.prcStat[i][j] - this.mean[i] * this.meanStat[j] - this.mean[j] * this.meanStat[i] + this.numStat * this.mean[i] * this.mean[j];
                    this.fullPrecMat[i][j] = (scatter + this.essMu * this.precTemp[i][j]) / denominator;
                }
            }
            this.fullPrecMat[i][i] = this.sds[i] * this.sds[i];
        }
    }

    /**
     * Runs a fixed-point iteration on {@link #sds}, visiting dimensions in random order.
     *
     * <p>Each {@code sds[i]} becomes the average, over all {@code j} with
     * {@code corrMat[i][j] != 0}, of {@code sqrt(fullPrecMat[i][i])} (for {@code j == i}) and of
     * {@code fullPrecMat[i][j]} divided by {@code sds[j]} and {@code corrMat[i][j]} (for
     * {@code j != i}).
     *
     * <p>Iteration stops once the summed squared change of a sweep is at most
     * {@link #FIXED_POINT_TOLERANCE}, or after {@link #MAX_FIXED_POINT_ITERATIONS} sweeps.
     */
    private void refineSdsForFixedCorrelation() {
        int[] perm = new int[this.dimension];
        double err = Double.POSITIVE_INFINITY;
        int sweeps = 0;
        while (err > FIXED_POINT_TOLERANCE && sweeps < MAX_FIXED_POINT_ITERATIONS) {
            ++sweeps;
            err = 0.0;
            Permutation.random(perm);
            for (int k = 0; k < perm.length; k++) {
                int i = perm[k];
                double sum = 0.0;
                double count = 0.0;
                for (int j = 0; j < this.dimension; j++) {
                    if (this.corrMat[i][j] != 0.0) {
                        sum += i == j ? Math.sqrt(this.fullPrecMat[i][j]) : this.fullPrecMat[i][j] / this.sds[j] / this.corrMat[i][j];
                        count += 1.0;
                    }
                }
                double updated = sum / count;
                double delta = this.sds[i] - updated;
                err += delta * delta;
                this.sds[i] = updated;
            }
        }
    }

    /**
     * No specific node shape is provided.
     *
     * @param forward unused
     * @return always {@code null}
     */
    @Override
    public String getNodeShape(boolean forward) {
        return null;
    }

    /**
     * Returns the node label {@code "G"}.
     *
     * @param weight unused
     * @param name unused
     * @param nf unused
     * @return {@code "G"}
     */
    @Override
    public String getNodeLabel(double weight, String name, NumberFormat nf) {
        return "G";
    }

    /**
     * Copies the mean, standard deviations, precision matrix and log determinant from {@code t}.
     *
     * @param t the emission to copy the parameters from
     * @throws IllegalArgumentException if {@code t} is not a
     *         {@code FixedCorrelationGaussianEmission} or has a different dimension
     */
    @Override
    public void setParameters(Emission t) throws IllegalArgumentException {
        if (!(t instanceof FixedCorrelationGaussianEmission)) {
            throw new IllegalArgumentException("expected a FixedCorrelationGaussianEmission, but got " + (t == null ? "null" : t.getClass().getName()));
        }
        FixedCorrelationGaussianEmission other = (FixedCorrelationGaussianEmission)t;
        if (other.dimension != this.dimension) {
            throw new IllegalArgumentException("dimension mismatch: " + this.dimension + " vs. " + other.dimension);
        }
        System.arraycopy(other.mean, 0, this.mean, 0, this.dimension);
        System.arraycopy(other.sds, 0, this.sds, 0, this.dimension);
        for (int i = 0; i < this.dimension; i++) {
            System.arraycopy(other.prePrec[i], 0, this.prePrec[i], 0, this.dimension);
        }
        this.logDet = other.logDet;
    }

    /**
     * Returns the mean vector, a blank line, and then one row per line of the covariance matrix
     * (the inverse of the current precision matrix).
     *
     * @return the textual representation
     */
    @Override
    public String toString() {
        return this.toString(null);
    }

    /**
     * Returns the mean vector, a blank line, and then one row per line of the covariance matrix
     * (the inverse of the current precision matrix). Numbers are formatted with {@code nf}.
     *
     * @param nf the number format to use; {@code null} falls back to {@link Arrays#toString(double[])}
     * @return the textual representation
     */
    @Override
    public String toString(NumberFormat nf) {
        // Invert into a local matrix so that printing leaves the scratch buffers untouched.
        double[][] covariance = new double[this.dimension][this.dimension];
        Inversion.compute(this.prePrec, covariance);
        StringBuilder sb = new StringBuilder();
        FixedCorrelationGaussianEmission.appendVector(sb, this.mean, nf);
        sb.append("\n\n");
        for (int i = 0; i < covariance.length; i++) {
            if (i > 0) {
                sb.append('\n');
            }
            FixedCorrelationGaussianEmission.appendVector(sb, covariance[i], nf);
        }
        return sb.toString();
    }

    /** Appends {@code values} as {@code [a, b, ...]}, formatting each entry with {@code nf} if non-null. */
    private static void appendVector(StringBuilder sb, double[] values, NumberFormat nf) {
        if (nf == null) {
            sb.append(Arrays.toString(values));
            return;
        }
        sb.append('[');
        for (int i = 0; i < values.length; i++) {
            if (i > 0) {
                sb.append(", ");
            }
            sb.append(nf.format(values[i]));
        }
        sb.append(']');
    }
}

