/*
 * Decompiled with CFR 0.152.
 */
package de.jstacs.sequenceScores.statisticalModels.trainable.hmm.states.emissions.continuous;

import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.io.NonParsableException;
import de.jstacs.io.XMLParser;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.states.emissions.DifferentiableEmission;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.states.emissions.Emission;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.IntList;
import de.jstacs.utils.random.RandomNumberGenerator;
import de.jtem.numericalMethods.calculus.specialFunctions.Gamma;
import java.text.NumberFormat;
import java.util.LinkedList;
import javax.naming.OperationNotSupportedException;

/**
 * Emission given by a univariate Gaussian (normal) density over continuous sequence values.
 *
 * <p>The free parameters are the mean {@code mu} and the log-precision {@code logPrecision}
 * (precision = 1 / variance). In the global parameter vector they occupy the indices
 * {@code offset} and {@code offset + 1}, where {@code offset} is assigned by
 * {@link #setParameterOffset(int)}.
 *
 * <p>If {@code ess > 0}, a prior is used with
 * {@code mu ~ N(priorMu, 1 / (ess * precision))} and a Gamma density with shape
 * {@code priorAlpha} and rate {@code priorBeta} on the precision (see
 * {@link #getLogPriorTerm()}). If {@code ess == 0}, no prior is used.
 *
 * <p>Only the forward direction is supported: the statistic and likelihood methods throw
 * {@link OperationNotSupportedException} for {@code forward == false}.
 */
public class GaussianEmission implements DifferentiableEmission {

    /** Shared generator used by {@link #initializeFunctionRandomly()}. */
    private static final RandomNumberGenerator rand = new RandomNumberGenerator();

    /** Equivalent sample size of the prior; {@code 0} disables the prior. */
    private double ess;
    /** Prior mean of {@code mu}. */
    private double priorMu;
    /** Shape parameter of the Gamma prior on the precision. */
    private double priorAlpha;
    /** Rate parameter of the Gamma prior on the precision. */
    private double priorBeta;
    /** Current mean of the Gaussian. */
    private double mu;
    /** Natural logarithm of {@link #precision}. */
    private double logPrecision;
    /** Current precision (inverse variance) of the Gaussian. */
    private double precision;
    /** Sufficient statistic: weighted SUM of observed values (despite the name, not divided by {@link #n}). */
    private double mean;
    /** Sufficient statistic: weighted SUM of squared observed values. */
    private double meansq;
    /** Sufficient statistic: sum of the weights. */
    private double n;
    /** Cached {@code 0.5 * (logPrecision - log(2 * pi))}, recomputed by {@link #precompute()}. */
    private double logNorm;
    /**
     * Selects the MAP estimator used by {@link #estimateFromStatistic()} when {@code ess > 0}:
     * {@code true} uses numerator {@code n + 2 * priorAlpha + 1}, {@code false} uses
     * {@code n + 2 * priorAlpha - 1}.
     */
    private boolean transformed;
    /** Index of {@code mu} in the global parameter vector; the log-precision is at {@code offset + 1}. */
    private int offset;
    /** Alphabet of the emitted values. */
    private AlphabetContainer con;

    /**
     * Creates an emission without prior ({@code ess = 0}, all prior hyper-parameters {@code 0}).
     *
     * @param con the alphabet of the emitted values
     */
    public GaussianEmission(AlphabetContainer con) {
        this(con, 0.0, 0.0, 0.0, 0.0, false);
    }

    /**
     * Creates an emission with explicitly given prior hyper-parameters.
     *
     * @param con the alphabet of the emitted values
     * @param ess the equivalent sample size of the prior ({@code 0} disables the prior)
     * @param priorMu the prior mean of {@code mu}
     * @param priorAlpha the shape of the Gamma prior on the precision
     * @param priorBeta the rate of the Gamma prior on the precision
     * @param transformed whether the MAP estimate uses the log-precision variant (see field doc)
     */
    public GaussianEmission(AlphabetContainer con, double ess, double priorMu, double priorAlpha, double priorBeta, boolean transformed) {
        this.con = con;
        this.ess = ess;
        this.priorMu = priorMu;
        this.priorAlpha = priorAlpha;
        this.priorBeta = priorBeta;
        this.transformed = transformed;
    }

    /**
     * Creates an emission whose Gamma prior on the precision is specified by its mode and
     * standard deviation instead of shape and rate.
     *
     * <p>The rate {@code beta} and shape {@code alpha} are chosen so that
     * {@code (alpha - 1) / beta == expectedPrecision} (the mode) and
     * {@code alpha / beta^2 == sdPrecision^2} (the variance).
     *
     * @param ess the equivalent sample size of the prior
     * @param con the alphabet of the emitted values
     * @param priorMu the prior mean of {@code mu}
     * @param expectedPrecision the mode of the Gamma prior on the precision
     * @param sdPrecision the standard deviation of the Gamma prior on the precision
     * @param transformed whether the MAP estimate uses the log-precision variant (see field doc)
     */
    public GaussianEmission(double ess, AlphabetContainer con, double priorMu, double expectedPrecision, double sdPrecision, boolean transformed) {
        this(con, ess, priorMu,
                gammaRateFromModeAndSd(expectedPrecision, sdPrecision) * expectedPrecision + 1.0,
                gammaRateFromModeAndSd(expectedPrecision, sdPrecision),
                transformed);
    }

    /**
     * Restores an emission from its XML representation as produced by {@link #toXML()}.
     *
     * @param xml the XML representation
     * @throws NonParsableException if {@code xml} cannot be parsed
     */
    public GaussianEmission(StringBuffer xml) throws NonParsableException {
        // NOTE(review): calls the overridable fromXML from a constructor; kept for compatibility.
        this.fromXML(xml);
    }

    /**
     * Returns the positive root {@code beta} of {@code sd^2 * beta^2 - mode * beta - 1 = 0}, i.e. the
     * rate of a Gamma distribution with the given mode and standard deviation.
     */
    private static double gammaRateFromModeAndSd(double mode, double sd) {
        double half = mode / (2.0 * sd * sd);
        return half + Math.sqrt(half * half + 1.0 / (sd * sd));
    }

    /**
     * Returns a field-by-field (shallow) copy; the {@link AlphabetContainer} is shared.
     */
    @Override
    public GaussianEmission clone() throws CloneNotSupportedException {
        return (GaussianEmission) super.clone();
    }

    /**
     * Adds the gradient of {@link #getLogPriorTerm()} with respect to {@code mu} and
     * {@code logPrecision} to {@code gradient}. Does nothing if {@code ess <= 0}.
     *
     * @param gradient the gradient vector to update in place
     * @param offset additional offset added to this emission's own parameter offset
     */
    @Override
    public void addGradientOfLogPriorTerm(double[] gradient, int offset) {
        if (this.ess > 0.0) {
            double diff = this.mu - this.priorMu;
            double gradMu = this.ess * this.precision * diff;
            int idx = this.offset + offset;
            // d/d mu of -0.5 * ess * precision * diff^2
            gradient[idx] -= gradMu;
            // d/d logPrecision of 0.5 * logPrecision - 0.5 * ess * precision * diff^2
            // + priorAlpha * logPrecision - priorBeta * precision
            gradient[idx + 1] += 0.5 - 0.5 * gradMu * diff + this.priorAlpha - this.priorBeta * this.precision;
        }
    }

    /**
     * Sums the sufficient statistics of all given emissions (each counted once) and writes the
     * sum back into every one of them, so that they all share the same statistics.
     *
     * @param emissions emissions to join; all are cast to {@link GaussianEmission}
     */
    @Override
    public void joinStatistics(Emission ... emissions) {
        for (Emission e : emissions) {
            if (e == this) continue;
            GaussianEmission other = (GaussianEmission) e;
            this.mean += other.mean;
            this.meansq += other.meansq;
            this.n += other.n;
        }
        for (Emission e : emissions) {
            GaussianEmission other = (GaussianEmission) e;
            other.mean = this.mean;
            other.meansq = this.meansq;
            other.n = this.n;
        }
    }

    /**
     * Adds the values at positions {@code startPos..endPos} (inclusive) of {@code seq}, each with
     * weight {@code weight}, to the sufficient statistics.
     *
     * @throws OperationNotSupportedException if {@code forward} is {@code false}
     */
    @Override
    public void addToStatistic(boolean forward, int startPos, int endPos, double weight, Sequence seq) throws OperationNotSupportedException {
        if (!forward) {
            throw new OperationNotSupportedException();
        }
        for (int pos = startPos; pos <= endPos; pos++) {
            double x = seq.continuousVal(pos);
            double w = weight * x;
            this.mean += w;
            this.meansq += w * x;
            this.n += weight;
        }
    }

    /**
     * Estimates {@code mu} and the precision from the current sufficient statistics.
     *
     * <p>Without prior ({@code ess == 0}) the weighted sample mean and the inverse weighted sample
     * variance are used. A non-positive variance yields precision {@code 0}. With prior, a MAP
     * estimate is computed whose precision numerator depends on {@link #transformed}.
     * The sufficient statistics themselves are not modified.
     */
    @Override
    public void estimateFromStatistic() {
        if (this.ess == 0.0) {
            // Use a local divisor so that the statistic n itself is not altered.
            double total = this.n == 0.0 ? 1.0 : this.n;
            this.mu = this.mean / total;
            double variance = this.meansq / total - this.mu * this.mu;
            // E[x^2] - E[x]^2 may be slightly negative due to cancellation; treat it like zero.
            this.precision = variance > 0.0 ? 1.0 / variance : 0.0;
        } else {
            this.mu = (this.mean + this.ess * this.priorMu) / (this.n + this.ess);
            // weighted sum of squared deviations from the new mu
            double s = this.meansq - 2.0 * this.mu * this.mean + this.n * this.mu * this.mu;
            double s0 = (this.mu - this.priorMu) * (this.mu - this.priorMu);
            double denominator = s + 2.0 * this.priorBeta + this.ess * s0;
            this.precision = this.transformed
                    ? (this.n + 2.0 * this.priorAlpha + 1.0) / denominator
                    : (this.n + 2.0 * this.priorAlpha - 1.0) / denominator;
        }
        this.logPrecision = Math.log(this.precision);
        this.precompute();
    }

    /**
     * Writes {@code mu} to {@code params[offset]} and {@code logPrecision} to
     * {@code params[offset + 1]}.
     */
    @Override
    public void fillCurrentParameter(double[] params) {
        params[this.offset] = this.mu;
        params[this.offset + 1] = this.logPrecision;
    }

    /**
     * Reads {@code mu} from {@code params[offset]} and {@code logPrecision} from
     * {@code params[offset + 1]}, and updates the derived values.
     */
    @Override
    public void setParameter(double[] params, int offset) {
        this.mu = params[offset];
        this.logPrecision = params[offset + 1];
        this.precision = Math.exp(this.logPrecision);
        this.precompute();
    }

    /**
     * Stores {@code offset} as the index of this emission's first parameter.
     *
     * @return the offset for the next emission ({@code offset + 2})
     */
    @Override
    public int setParameterOffset(int offset) {
        this.offset = offset;
        return offset + 2;
    }

    /**
     * Restores all fields from the XML representation written by {@link #toXML()}.
     *
     * @param xml the XML representation
     * @throws NonParsableException if {@code xml} cannot be parsed
     */
    protected void fromXML(StringBuffer xml) throws NonParsableException {
        xml = XMLParser.extractForTag(xml, this.getClass().getSimpleName());
        this.con = (AlphabetContainer)XMLParser.extractObjectForTags(xml, "alphabet");
        this.ess = XMLParser.extractObjectForTags(xml, "ess", Double.TYPE);
        this.priorMu = XMLParser.extractObjectForTags(xml, "priorMu", Double.TYPE);
        this.priorAlpha = XMLParser.extractObjectForTags(xml, "priorAlpha", Double.TYPE);
        this.priorBeta = XMLParser.extractObjectForTags(xml, "priorBeta", Double.TYPE);
        this.mu = XMLParser.extractObjectForTags(xml, "mu", Double.TYPE);
        this.logPrecision = XMLParser.extractObjectForTags(xml, "logPrecision", Double.TYPE);
        this.precision = Math.exp(this.logPrecision);
        this.mean = XMLParser.extractObjectForTags(xml, "mean", Double.TYPE);
        this.meansq = XMLParser.extractObjectForTags(xml, "meansq", Double.TYPE);
        this.transformed = XMLParser.extractObjectForTags(xml, "transformed", Boolean.TYPE);
        this.n = XMLParser.extractObjectForTags(xml, "n", Double.TYPE);
        this.offset = XMLParser.extractObjectForTags(xml, "offset", Integer.TYPE);
        this.precompute();
    }

    /**
     * Returns the log prior density of ({@code mu}, {@code logPrecision}), or {@code 0} if
     * {@code ess <= 0}.
     *
     * <p>The normal part is {@code N(mu | priorMu, 1 / (ess * precision))}. The Gamma part uses
     * {@code priorAlpha * logPrecision} rather than {@code (priorAlpha - 1) * logPrecision}, i.e. it
     * includes the Jacobian of the change of variables to the log-precision.
     */
    @Override
    public double getLogPriorTerm() {
        if (this.ess > 0.0) {
            double diff = this.mu - this.priorMu;
            return 0.5 * (Math.log(this.ess / (Math.PI * 2)) + this.logPrecision - this.ess * this.precision * diff * diff)
                    + this.priorAlpha * Math.log(this.priorBeta)
                    - Gamma.logOfGamma(this.priorAlpha)
                    + this.priorAlpha * this.logPrecision
                    - this.priorBeta * this.precision;
        }
        return 0.0;
    }

    /**
     * Returns the log density of the values at {@code startPos..endPos} (inclusive). It also appends
     * the partial derivatives with respect to {@code mu} and {@code logPrecision} to
     * {@code partDer}, and their parameter indices to {@code indices}.
     *
     * @throws OperationNotSupportedException if {@code forward} is {@code false}
     */
    @Override
    public double getLogProbAndPartialDerivationFor(boolean forward, int startPos, int endPos, IntList indices, DoubleList partDer, Sequence seq) throws OperationNotSupportedException {
        if (!forward) {
            throw new OperationNotSupportedException();
        }
        double score = 0.0;
        double derivMu = 0.0;
        double derivLogPrecision = 0.0;
        for (int pos = startPos; pos <= endPos; pos++) {
            double diff = seq.continuousVal(pos) - this.mu;
            derivMu += this.precision * diff;
            derivLogPrecision += 0.5 * (1.0 - this.precision * diff * diff);
            score += this.logNorm - 0.5 * diff * diff * this.precision;
        }
        indices.add(this.offset);
        partDer.add(derivMu);
        indices.add(this.offset + 1);
        partDer.add(derivLogPrecision);
        return score;
    }

    /**
     * Returns the log density of the values at {@code startPos..endPos} (inclusive).
     *
     * @throws OperationNotSupportedException if {@code forward} is {@code false}, consistent with
     *         {@link #addToStatistic} and {@link #getLogProbAndPartialDerivationFor}
     */
    @Override
    public double getLogProbFor(boolean forward, int startPos, int endPos, Sequence seq) throws OperationNotSupportedException {
        if (!forward) {
            throw new OperationNotSupportedException();
        }
        double score = 0.0;
        for (int pos = startPos; pos <= endPos; pos++) {
            double diff = seq.continuousVal(pos) - this.mu;
            score += this.logNorm - 0.5 * diff * diff * this.precision;
        }
        return score;
    }

    /**
     * Draws random parameters. Without prior the precision is drawn as {@code nextGamma(1, 1)} and
     * {@code mu ~ N(priorMu, 1 / precision)}. With prior, the draw is
     * {@code nextGamma(priorAlpha, 1 / priorBeta)} and {@code mu ~ N(priorMu, 1 / (ess * precision))},
     * matching the normal part of {@link #getLogPriorTerm()}.
     */
    @Override
    public void initializeFunctionRandomly() {
        // Assumes rand.nextGaussian() is standard normal, so the standard deviation is
        // 1 / sqrt(precision of mu) (previously divided by the precision itself).
        if (this.ess == 0.0) {
            this.precision = rand.nextGamma(1.0, 1.0);
            this.mu = rand.nextGaussian() / Math.sqrt(this.precision) + this.priorMu;
        } else {
            this.precision = rand.nextGamma(this.priorAlpha, 1.0 / this.priorBeta);
            this.mu = rand.nextGaussian() / Math.sqrt(this.ess * this.precision) + this.priorMu;
        }
        this.logPrecision = Math.log(this.precision);
        this.precompute();
    }

    /** Recomputes the cached normalization constant {@link #logNorm} from {@link #logPrecision}. */
    protected void precompute() {
        this.logNorm = 0.5 * (this.logPrecision - Math.log(Math.PI * 2));
    }

    /** Sets all sufficient statistics to zero. */
    @Override
    public void resetStatistic() {
        this.n = 0.0;
        this.meansq = 0.0;
        this.mean = 0.0;
    }

    /**
     * Returns an XML representation of all parameters, hyper-parameters and statistics, enclosed in
     * a tag named after the simple class name.
     */
    @Override
    public StringBuffer toXML() {
        StringBuffer buf = new StringBuffer();
        XMLParser.appendObjectWithTags(buf, this.con, "alphabet");
        XMLParser.appendObjectWithTags(buf, this.ess, "ess");
        XMLParser.appendObjectWithTags(buf, this.priorMu, "priorMu");
        XMLParser.appendObjectWithTags(buf, this.priorAlpha, "priorAlpha");
        XMLParser.appendObjectWithTags(buf, this.priorBeta, "priorBeta");
        XMLParser.appendObjectWithTags(buf, this.mu, "mu");
        XMLParser.appendObjectWithTags(buf, this.logPrecision, "logPrecision");
        XMLParser.appendObjectWithTags(buf, this.mean, "mean");
        XMLParser.appendObjectWithTags(buf, this.meansq, "meansq");
        XMLParser.appendObjectWithTags(buf, this.n, "n");
        XMLParser.appendObjectWithTags(buf, this.transformed, "transformed");
        XMLParser.appendObjectWithTags(buf, this.offset, "offset");
        XMLParser.addTags(buf, this.getClass().getSimpleName());
        return buf;
    }

    @Override
    public AlphabetContainer getAlphabetContainer() {
        return this.con;
    }

    /** Returns the density formula with the current precision and mean filled in. */
    @Override
    public String toString() {
        return "p = sqrt(" + this.precision + "/(2*pi)) * exp( -0.5 * " + this.precision + " * (x - " + this.mu + ")^2 );\n";
    }

    @Override
    public String getNodeShape(boolean forward) {
        return "\"box\"";
    }

    @Override
    public String getNodeLabel(double weight, String name, NumberFormat nf) {
        return "\"" + name + "\"";
    }

    /** Adds the indices of {@code mu} and {@code logPrecision} as a single sampling group. */
    @Override
    public void fillSamplingGroups(int parameterOffset, LinkedList<int[]> list) {
        list.add(new int[]{parameterOffset + this.offset, parameterOffset + this.offset + 1});
    }

    /** Returns {@code 2} ({@code mu} and {@code logPrecision}). */
    @Override
    public int getNumberOfParameters() {
        return 2;
    }

    /** Returns {@code 0}; presumably because the emitted values are continuous. */
    @Override
    public int getSizeOfEventSpace() {
        return 0;
    }

    /**
     * Copies {@code mu}, the precision and the derived values from {@code t}.
     *
     * @throws IllegalArgumentException if {@code t} is not of exactly the same class
     */
    @Override
    public void setParameters(Emission t) throws IllegalArgumentException {
        if (!t.getClass().equals(this.getClass())) {
            throw new IllegalArgumentException("The emissions are not comparable.");
        }
        GaussianEmission other = (GaussianEmission) t;
        this.mu = other.mu;
        this.logPrecision = other.logPrecision;
        this.precision = other.precision;
        this.logNorm = other.logNorm;
    }
}

