/*
 * Decompiled with CFR 0.152.
 */
package de.jstacs.sequenceScores.statisticalModels.differentiable.continuous;

import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.DataSet;
import de.jstacs.data.DinucleotideProperty;
import de.jstacs.data.sequences.ArbitrarySequence;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.io.NonParsableException;
import de.jstacs.io.XMLParser;
import de.jstacs.sequenceScores.statisticalModels.differentiable.AbstractDifferentiableStatisticalModel;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.IntList;
import de.jstacs.utils.random.RandomNumberGenerator;
import de.jtem.numericalMethods.calculus.specialFunctions.Gamma;
import java.text.NumberFormat;
import java.util.Random;

/**
 * A differentiable statistical model of length 1 that scores a single continuous value with a
 * univariate Gaussian density.
 * <p>
 * The scored value is either the continuous value of the sequence at the given position or, if a
 * {@link DinucleotideProperty} has been set via {@link #setDinucleotideProperty(DinucleotideProperty)},
 * the value of that property at the position. Positions whose continuous value is {@code NaN}
 * contribute a log score of {@code 0}.
 * </p>
 * <p>
 * The free parameters are the mean {@code mu} (unless the mean is fixed) followed by the
 * log-precision. The prior (see {@link #getLogPriorTerm()}) is a Normal-Gamma prior:
 * {@code mu ~ N(priorMu, 1/(ess*precision))} and a Gamma prior with shape {@code priorAlpha} and
 * rate {@code priorBeta} on the precision, written in terms of the log-precision.
 * </p>
 */
public class SingleGaussianDiffSM
extends AbstractDifferentiableStatisticalModel {
    /** Shared generator used for random initialization. */
    private static final RandomNumberGenerator rand = new RandomNumberGenerator();
    /** Equivalent sample size; scales the prior precision of {@code mu}. */
    private double ess;
    /** Prior mean of {@code mu}. */
    private double priorMu;
    /** Shape of the Gamma prior on the precision. */
    private double priorAlpha;
    /** Rate of the Gamma prior on the precision. */
    private double priorBeta;
    /** Mean of the Gaussian. */
    private double mu;
    /** Logarithm of the precision; this (not the precision) is the free parameter. */
    private double logPrecision;
    /** Precision (inverse variance); kept equal to {@code exp(logPrecision)}. */
    private double precision;
    /** Whether the parameters have been initialized. */
    private boolean initialized;
    /** If {@code true}, {@link #initializeFunction} always initializes randomly. */
    private boolean alwaysInitRandomly;
    /** Cached {@code 0.5 * (logPrecision - log(2*pi))}. */
    private double logNorm;
    /** Optional property scored instead of the raw continuous value; may be {@code null}. */
    private DinucleotideProperty prop;
    /** If {@code true}, {@code mu} is not a free parameter. */
    private boolean fixMu;

    /**
     * Creates a new model from explicit prior hyper-parameters.
     *
     * @param ess equivalent sample size scaling the prior precision of {@code mu}
     * @param priorMu prior mean of {@code mu}
     * @param priorAlpha shape of the Gamma prior on the precision
     * @param priorBeta rate of the Gamma prior on the precision
     * @param alphabet the alphabet container of the model
     * @param alwaysInitRandomly if {@code true}, {@link #initializeFunction} always initializes randomly
     * @param fixMu if {@code true}, {@code mu} is kept fixed and is not a free parameter
     * @param mu the initial (or fixed) mean
     */
    public SingleGaussianDiffSM(double ess, double priorMu, double priorAlpha, double priorBeta, AlphabetContainer alphabet, boolean alwaysInitRandomly, boolean fixMu, double mu) {
        super(alphabet, 1);
        this.fixMu = fixMu;
        this.mu = mu;
        this.ess = ess;
        this.priorMu = priorMu;
        this.priorAlpha = priorAlpha;
        this.priorBeta = priorBeta;
        this.initialized = false;
        this.alwaysInitRandomly = alwaysInitRandomly;
    }

    /**
     * Creates a new model with a free mean (initially {@code 0}). The Gamma prior on the precision
     * is parameterized by its mode and standard deviation.
     * <p>
     * The hyper-parameters are chosen such that {@code (priorAlpha - 1) / priorBeta == expectedPrecision}
     * (the mode) and {@code priorAlpha / priorBeta^2 == sdPrecision^2} (the variance).
     * </p>
     *
     * @param alphabet the alphabet container of the model
     * @param ess equivalent sample size scaling the prior precision of {@code mu}
     * @param priorMu prior mean of {@code mu}
     * @param expectedPrecision mode of the Gamma prior on the precision
     * @param sdPrecision standard deviation of the Gamma prior on the precision
     * @param alwaysInitRandomly if {@code true}, {@link #initializeFunction} always initializes randomly
     */
    public SingleGaussianDiffSM(AlphabetContainer alphabet, double ess, double priorMu, double expectedPrecision, double sdPrecision, boolean alwaysInitRandomly) {
        this(ess, priorMu,
                priorRateFor(expectedPrecision, sdPrecision) * expectedPrecision + 1.0,
                priorRateFor(expectedPrecision, sdPrecision),
                alphabet, alwaysInitRandomly, false, 0.0);
    }

    /**
     * Solves {@code sd^2 * beta^2 - mode * beta - 1 = 0} for the positive root, i.e. the Gamma rate
     * for which shape {@code beta*mode + 1} yields the requested mode and standard deviation.
     */
    private static double priorRateFor(double expectedPrecision, double sdPrecision) {
        double half = expectedPrecision / (2.0 * sdPrecision * sdPrecision);
        return half + Math.sqrt(half * half + 1.0 / (sdPrecision * sdPrecision));
    }

    /**
     * Restores a model from its XML representation.
     *
     * @param buf the XML representation, as produced by {@link #toXML()}
     * @throws NonParsableException if {@code buf} cannot be parsed
     */
    public SingleGaussianDiffSM(StringBuffer buf) throws NonParsableException {
        super(buf);
    }

    @Override
    public SingleGaussianDiffSM clone() throws CloneNotSupportedException {
        return (SingleGaussianDiffSM)super.clone();
    }

    /**
     * Samples sequences of continuous values, each value drawn independently from
     * {@code N(mu, 1/precision)}.
     *
     * @param numberOfSequences the number of sequences to emit
     * @param seqLength either a single length used for all sequences, or one length per sequence
     * @return the sampled data set
     * @throws Exception if a sequence or the data set cannot be created
     */
    @Override
    public DataSet emitDataSet(int numberOfSequences, int ... seqLength) throws Exception {
        Sequence[] seqs = new Sequence[numberOfSequences];
        Random r = new Random();
        int commonLength = seqLength[0];
        double sd = Math.sqrt(1.0 / this.precision);
        for (int i = 0; i < numberOfSequences; i++) {
            double[] values = new double[seqLength.length > 1 ? seqLength[i] : commonLength];
            for (int k = 0; k < values.length; k++) {
                // Gaussian (not uniform) draws, so the samples actually follow this model.
                values[k] = r.nextGaussian() * sd + this.mu;
            }
            seqs[i] = new ArbitrarySequence(this.alphabets, values);
        }
        return new DataSet("sampled from " + this.getInstanceName(), seqs);
    }

    /**
     * Sets the dinucleotide property whose value is scored instead of the raw continuous value.
     *
     * @param prop the property, or {@code null} to score the continuous value directly
     */
    public void setDinucleotideProperty(DinucleotideProperty prop) {
        this.prop = prop;
    }

    @Override
    protected void fromXML(StringBuffer xml) throws NonParsableException {
        this.length = 1;
        xml = XMLParser.extractForTag(xml, this.getClass().getSimpleName());
        this.alphabets = XMLParser.extractObjectForTags(xml, "alphabet", AlphabetContainer.class);
        this.ess = XMLParser.extractObjectForTags(xml, "ess", Double.TYPE);
        this.priorMu = XMLParser.extractObjectForTags(xml, "priorMu", Double.TYPE);
        this.priorAlpha = XMLParser.extractObjectForTags(xml, "priorAlpha", Double.TYPE);
        this.priorBeta = XMLParser.extractObjectForTags(xml, "priorBeta", Double.TYPE);
        this.fixMu = XMLParser.extractObjectForTags(xml, "fixMu", Boolean.TYPE);
        this.mu = XMLParser.extractObjectForTags(xml, "mu", Double.TYPE);
        this.logPrecision = XMLParser.extractObjectForTags(xml, "logPrecision", Double.TYPE);
        this.precision = Math.exp(this.logPrecision);
        this.initialized = XMLParser.extractObjectForTags(xml, "initialized", Boolean.TYPE);
        this.prop = XMLParser.extractObjectForTags(xml, "prop", DinucleotideProperty.class);
        this.precomputeNormalization();
    }

    @Override
    public StringBuffer toXML() {
        // NOTE(review): alwaysInitRandomly is neither written here nor read in fromXML, so a
        // deserialized model always has it set to false — confirm whether that is intended.
        StringBuffer buf = new StringBuffer();
        XMLParser.appendObjectWithTags(buf, this.alphabets, "alphabet");
        XMLParser.appendObjectWithTags(buf, this.ess, "ess");
        XMLParser.appendObjectWithTags(buf, this.priorMu, "priorMu");
        XMLParser.appendObjectWithTags(buf, this.priorAlpha, "priorAlpha");
        XMLParser.appendObjectWithTags(buf, this.priorBeta, "priorBeta");
        XMLParser.appendObjectWithTags(buf, this.fixMu, "fixMu");
        XMLParser.appendObjectWithTags(buf, this.mu, "mu");
        XMLParser.appendObjectWithTags(buf, this.logPrecision, "logPrecision");
        XMLParser.appendObjectWithTags(buf, this.initialized, "initialized");
        XMLParser.appendObjectWithTags(buf, (Object)this.prop, "prop");
        XMLParser.addTags(buf, this.getClass().getSimpleName());
        return buf;
    }

    /**
     * Adds the gradient of {@link #getLogPriorTerm()} with respect to the parameters
     * ({@code mu} unless fixed, then the log-precision) to {@code grad}, starting at {@code start}.
     */
    @Override
    public void addGradientOfLogPriorTerm(double[] grad, int start) throws Exception {
        double diff = this.mu - this.priorMu;
        double gradMu = this.ess * this.precision * diff;
        int pos = start;
        if (!this.fixMu) {
            grad[pos++] -= gradMu;
        }
        // d/d(logPrecision) of the log prior: 0.5 - 0.5*ess*P*diff^2 + alpha - beta*P
        grad[pos] += 0.5 - 0.5 * gradMu * diff + this.priorAlpha - this.priorBeta * this.precision;
    }

    @Override
    public double getESS() {
        return this.ess;
    }

    /**
     * Returns the log of the Normal-Gamma prior: {@code log N(mu; priorMu, 1/(ess*precision))}
     * plus the Gamma(priorAlpha, priorBeta) log-density of the precision, written in terms of the
     * log-precision (hence the factor {@code priorAlpha} rather than {@code priorAlpha - 1}).
     */
    @Override
    public double getLogPriorTerm() {
        double val = this.mu - this.priorMu;
        return 0.5 * (Math.log(this.ess / (Math.PI * 2)) + this.logPrecision - this.ess * this.precision * val * val) + this.priorAlpha * Math.log(this.priorBeta) - Gamma.logOfGamma(this.priorAlpha) + this.priorAlpha * this.logPrecision - this.priorBeta * this.precision;
    }

    /** The Gaussian density is normalized, so the log normalization constant is {@code 0}. */
    @Override
    public double getLogNormalizationConstant() {
        return 0.0;
    }

    @Override
    public double getLogPartialNormalizationConstant(int parameterIndex) throws Exception {
        return Double.NEGATIVE_INFINITY;
    }

    @Override
    public int getSizeOfEventSpaceForRandomVariablesOfParameter(int index) {
        return 0;
    }

    /**
     * Returns the current parameters: {@code [logPrecision]} if the mean is fixed, otherwise
     * {@code [mu, logPrecision]}.
     */
    @Override
    public double[] getCurrentParameterValues() throws Exception {
        return this.fixMu
                ? new double[]{this.logPrecision}
                : new double[]{this.mu, this.logPrecision};
    }

    @Override
    public String getInstanceName() {
        return this.getClass().getSimpleName() + " with " + this.mu + " and " + this.precision;
    }

    /**
     * Returns the value scored at {@code pos}: the dinucleotide property if one is set, otherwise
     * the continuous value of {@code seq}.
     */
    private double rawValue(Sequence seq, int pos) {
        return this.prop == null ? seq.continuousVal(pos) : this.prop.getProperty(seq, pos);
    }

    /**
     * Returns the Gaussian log-density of the value at {@code start}, or {@code 0} if the continuous
     * value there is {@code NaN}.
     */
    @Override
    public double getLogScoreFor(Sequence seq, int start) {
        if (Double.isNaN(seq.continuousVal(start))) {
            return 0.0;
        }
        double val = this.rawValue(seq, start) - this.mu;
        return this.logNorm - 0.5 * val * val * this.precision;
    }

    /**
     * Like {@link #getLogScoreFor(Sequence, int)}, and additionally appends the partial derivatives
     * with respect to {@code mu} (index 0, unless fixed) and the log-precision (last index).
     */
    @Override
    public double getLogScoreAndPartialDerivation(Sequence seq, int start, IntList indices, DoubleList partialDer) {
        if (Double.isNaN(seq.continuousVal(start))) {
            return 0.0;
        }
        double val = this.rawValue(seq, start) - this.mu;
        int idx = 0;
        if (!this.fixMu) {
            indices.add(idx++);
            partialDer.add(this.precision * val);
        }
        indices.add(idx);
        partialDer.add(0.5 * (1.0 - this.precision * val * val));
        return this.logNorm - 0.5 * val * val * this.precision;
    }

    @Override
    public int getNumberOfParameters() {
        return this.fixMu ? 1 : 2;
    }

    /**
     * Initializes the parameters from the weighted data in {@code data[index]}. It uses the
     * posterior mean of {@code mu} and the Normal-Gamma posterior estimate of the precision. If
     * {@code alwaysInitRandomly} is set, it delegates to {@link #initializeFunctionRandomly(boolean)}.
     * Elements whose continuous value at position 0 is {@code NaN} are skipped. Without usable data,
     * it falls back to the prior.
     */
    @Override
    public void initializeFunction(int index, boolean freeParams, DataSet[] data, double[][] weights) throws Exception {
        if (this.alwaysInitRandomly) {
            this.initializeFunctionRandomly(freeParams);
            return;
        }
        // A free mean must be estimated from the raw values. Centring on the current mu would make
        // the estimate depend on a previous initialization. With a fixed mean, the values are
        // centred on it.
        double offset = this.fixMu ? this.mu : 0.0;
        DataSet ds = data[index];
        double[] w = weights == null ? null : weights[index];
        double x = 0.0;
        double xsq = 0.0;
        double norm = 0.0;
        for (int i = 0; i < ds.getNumberOfElements(); i++) {
            Sequence seq = ds.getElementAt(i);
            if (Double.isNaN(seq.continuousVal(0))) {
                continue;
            }
            double wi = w == null ? 1.0 : w[i];
            double temp = this.rawValue(seq, 0) - offset;
            x += wi * temp;
            xsq += wi * temp * temp;
            norm += wi;
        }
        double mean = x / norm;
        double var = xsq / norm - mean * mean;
        if (!this.fixMu) {
            this.mu = (x + this.ess * this.priorMu) / (norm + this.ess);
            if (Double.isNaN(this.mu) && (norm == 0.0 || Double.isNaN(norm))) {
                this.mu = this.priorMu;
            }
        }
        double betap = this.priorBeta + 0.5 * norm * var + this.ess * norm * (mean - this.priorMu) * (mean - this.priorMu) / (2.0 * (this.ess + norm));
        this.precision = (2.0 * this.priorAlpha + norm - 1.0) / (2.0 * betap);
        if (Double.isNaN(this.precision) && (norm == 0.0 || Double.isNaN(norm))) {
            this.precision = (2.0 * this.priorAlpha - 1.0) / (2.0 * this.priorBeta);
        }
        this.logPrecision = Math.log(this.precision);
        this.precomputeNormalization();
        this.initialized = true;
    }

    /**
     * Draws the precision from the Gamma prior and, unless fixed, the mean from its conditional
     * prior {@code N(priorMu, 1/(ess*precision))}.
     */
    @Override
    public void initializeFunctionRandomly(boolean freeParams) throws Exception {
        this.precision = rand.nextGamma(this.priorAlpha, 1.0 / this.priorBeta);
        this.logPrecision = Math.log(this.precision);
        if (!this.fixMu) {
            // The standard deviation is 1/sqrt(ess*precision), not 1/precision. A non-positive ess
            // would give an infinite scale, so in that case fall back to the precision alone.
            double muPrecision = this.ess > 0.0 ? this.ess * this.precision : this.precision;
            this.mu = rand.nextGaussian() / Math.sqrt(muPrecision) + this.priorMu;
        }
        this.precomputeNormalization();
        this.initialized = true;
    }

    @Override
    public boolean isInitialized() {
        return this.initialized;
    }

    /**
     * Sets the parameters from {@code params} starting at {@code start}: {@code mu} first (unless
     * fixed), then the log-precision.
     */
    @Override
    public void setParameters(double[] params, int start) {
        int pos = start;
        if (!this.fixMu) {
            this.mu = params[pos++];
        }
        this.logPrecision = params[pos];
        this.precision = Math.exp(this.logPrecision);
        this.precomputeNormalization();
    }

    /** Recomputes the cached log normalization {@code 0.5 * (logPrecision - log(2*pi))}. */
    private void precomputeNormalization() {
        this.logNorm = 0.5 * (this.logPrecision - Math.log(Math.PI * 2));
    }

    @Override
    public String toString(NumberFormat nf) {
        return this.getClass().getSimpleName() + " with mu=" + nf.format(this.mu) + " precision=" + nf.format(this.precision) + (this.prop == null ? "" : " prop=" + this.prop.name());
    }
}

