/*
 * Decompiled with CFR 0.152.
 */
package projects.kmermotifs;

import de.jstacs.utils.DoubleList;
import de.jstacs.utils.IntList;
import de.jstacs.utils.random.RandomNumberGenerator;
import de.jtem.numericalMethods.calculus.specialFunctions.Gamma;
import java.util.Arrays;
import projects.kmermotifs.PositionStatisticsFunction;

/**
 * A {@link PositionStatisticsFunction} that scores positions with a univariate Gaussian
 * {@code N(mu, 1/prec)}.
 *
 * <p>The free parameters (see {@link #setParameters(double[], int)}) are {@code mu} and
 * {@code log(prec)}. When {@code ess > 0} a Normal-Gamma prior is applied:
 * {@code prec ~ Gamma(shape = priorAlpha, rate = priorBeta)} and
 * {@code mu | prec ~ N(priorMu, 1 / (ess * prec))}.
 *
 * <p>Sufficient statistics are stored in a {@code double[2]}: index 0 holds the weighted sum of
 * positions, index 1 the weighted sum of squared positions.
 */
public class GaussianPositionStatistic
implements PositionStatisticsFunction {
    /** Shared source of randomness for {@link #initializeFunctionRandomly()}. */
    private static final RandomNumberGenerator rand = new RandomNumberGenerator();
    /** {@code log(2 * pi)}, the normalising constant of the Gaussian log-density. */
    private static final double LOG_TWO_PI = Math.log(Math.PI * 2);

    /** Log of the precision; this (not {@link #prec}) is the optimised parameter. */
    private double logPrec;
    /** Precision {@code exp(logPrec)}, cached to avoid repeated exponentiation. */
    private double prec;
    /** Mean of the Gaussian. */
    private double mu;
    /** Equivalent sample size of the prior on {@code mu}; values {@code <= 0} disable the prior. */
    private double ess;
    /** Prior mean of {@code mu}. */
    private double priorMu;
    /** Shape of the Gamma prior on the precision. */
    private double priorAlpha;
    /** Rate of the Gamma prior on the precision. */
    private double priorBeta;

    /**
     * Creates a Gaussian position statistic with explicit Gamma hyper-parameters.
     *
     * @param ess      equivalent sample size of the prior ({@code <= 0} disables the prior)
     * @param priorMu  prior mean of the Gaussian mean
     * @param priorHyp {@code {shape, rate}} of the Gamma prior on the precision
     */
    public GaussianPositionStatistic(double ess, double priorMu, double[] priorHyp) {
        this.ess = ess;
        this.priorMu = priorMu;
        this.priorAlpha = priorHyp[0];
        this.priorBeta = priorHyp[1];
    }

    /**
     * Creates a Gaussian position statistic whose Gamma prior on the precision is derived from a
     * target precision and its standard deviation.
     *
     * <p>Note: the Gamma hyper-parameters are chosen so that the <em>mode</em>
     * {@code (shape - 1) / rate} equals {@code expectedPrecision} and the variance
     * {@code shape / rate^2} equals {@code sdPrecision^2}.
     *
     * @param ess               equivalent sample size of the prior ({@code <= 0} disables the prior)
     * @param priorMu           prior mean of the Gaussian mean
     * @param expectedPrecision targeted mode of the Gamma prior on the precision
     * @param sdPrecision       standard deviation of the Gamma prior on the precision
     */
    public GaussianPositionStatistic(double ess, double priorMu, double expectedPrecision, double sdPrecision) {
        this(ess, priorMu, gammaHyperparameters(expectedPrecision, sdPrecision));
    }

    /**
     * Solves for Gamma {@code {shape, rate}} with mode {@code mode} and standard deviation
     * {@code sd}: {@code rate} is the positive root of {@code sd^2 r^2 - mode r - 1 = 0} and
     * {@code shape = rate * mode + 1}.
     */
    private static double[] gammaHyperparameters(double mode, double sd) {
        double halfRatio = mode / (2.0 * sd * sd);
        double rate = halfRatio + Math.sqrt(halfRatio * halfRatio + 1.0 / (sd * sd));
        double shape = rate * mode + 1.0;
        return new double[]{shape, rate};
    }

    @Override
    public GaussianPositionStatistic clone() throws CloneNotSupportedException {
        // every instance field is a primitive double, so the shallow copy is already complete
        return (GaussianPositionStatistic) super.clone();
    }

    /** Not implemented: always returns {@code null}. */
    @Override
    public StringBuffer toXML() {
        return null;
    }

    /** Returns a fresh, zeroed {@code double[2]} statistic (weighted sum, weighted sum of squares). */
    @Override
    public Object getEmptyStatistics() {
        return new double[2];
    }

    /**
     * Adds {@code weight * position} and {@code weight * position^2} to the statistic.
     *
     * @param stat     a {@code double[2]} obtained from {@link #getEmptyStatistics()}
     * @param position the observed position
     * @param weight   the weight of this observation
     */
    @Override
    public void addToStatistics(Object stat, int position, double weight) throws IllegalArgumentException {
        double[] sums = (double[]) stat;
        sums[0] += (double) position * weight;
        // square in floating point: int position * position overflows for |position| > 46340
        sums[1] += (double) position * position * weight;
    }

    /** Divides both entries of the {@code double[2]} statistic by {@code norm} in place. */
    @Override
    public void normalize(Object stat, double norm) {
        double[] sums = (double[]) stat;
        sums[0] /= norm;
        sums[1] /= norm;
    }

    /** Adds the entries of {@code toAdd} to {@code stat} element-wise (both {@code double[2]}). */
    @Override
    public void addToStatistic(Object stat, Object toAdd) throws IllegalArgumentException {
        double[] sums = (double[]) stat;
        double[] other = (double[]) toAdd;
        sums[0] += other[0];
        sums[1] += other[1];
    }

    /**
     * Returns the Gaussian log-density averaged over the accumulated positions after shifting them
     * by {@code shift}: {@code 0.5 (log prec - log 2pi) - prec/2 * E[(x + shift - mu)^2]}, where
     * the expectation uses the statistic divided by {@code n}.
     *
     * <p>NOTE(review): {@code n} presumably is the total weight accumulated into {@code stat};
     * {@code n == 0} yields NaN — confirm callers never pass it.
     */
    @Override
    public double getLogScoreFor(Object stat, double n, int shift) throws IllegalArgumentException {
        double[] moments = shiftedMoments(stat, shift, n);
        return 0.5 * (this.logPrec - LOG_TWO_PI) - this.halfPrecisionSquaredError(moments);
    }

    /**
     * Same score as {@link #getLogScoreFor(Object, double, int)}; additionally appends the partial
     * derivative with respect to {@code mu} (parameter index 0) and with respect to
     * {@code log(prec)} (parameter index 1) to {@code indices}/{@code partialDer}.
     */
    @Override
    public double getLogScoreAndPartialDerivationFor(Object stat, double n, int shift, IntList indices, DoubleList partialDer) {
        double[] moments = shiftedMoments(stat, shift, n);
        indices.add(0);
        partialDer.add(this.prec * (moments[0] - this.mu));
        indices.add(1);
        double penalty = this.halfPrecisionSquaredError(moments);
        // d/d(logPrec) of 0.5*logPrec - exp(logPrec)/2 * E[(x-mu)^2]
        partialDer.add(0.5 - penalty);
        return 0.5 * (this.logPrec - LOG_TWO_PI) - penalty;
    }

    /** Returns a copy of the {@code double[2]} statistic, shifted by {@code shift} and divided by {@code n}. */
    private static double[] shiftedMoments(Object stat, int shift, double n) {
        double[] moments = ((double[]) stat).clone();
        adjust(moments, shift, n);
        return moments;
    }

    /** Returns {@code prec/2 * (E[x^2] - 2 mu E[x] + mu^2)} for normalised moments {@code {E[x], E[x^2]}}. */
    private double halfPrecisionSquaredError(double[] moments) {
        return this.prec / 2.0 * (moments[1] - 2.0 * this.mu * moments[0] + this.mu * this.mu);
    }

    /**
     * In place, converts raw sums {@code {S1, S2}} of positions x into the first two moments of
     * {@code x + shift} normalised by {@code n}:
     * {@code S2 <- (S2 + 2 shift S1 + n shift^2) / n}, {@code S1 <- (S1 + shift n) / n}.
     * The squared entry is updated first because it needs the unshifted {@code S1}.
     */
    private static void adjust(double[] stat, int shift, double n) {
        stat[1] += 2.0 * (double) shift * stat[0] + n * (double) shift * (double) shift;
        stat[1] /= n;
        stat[0] += (double) shift * n;
        stat[0] /= n;
    }

    /**
     * Returns the log of the Normal-Gamma prior density for {@code (mu, log prec)}, or 0 when
     * {@code ess <= 0}. The Gamma term uses {@code priorAlpha * logPrec} (not
     * {@code priorAlpha - 1}) because it includes the Jacobian of {@code prec = exp(logPrec)}.
     */
    @Override
    public double getLogPriorTerm() {
        if (this.ess > 0.0) {
            double val = this.mu - this.priorMu;
            return 0.5 * (Math.log(this.ess / (Math.PI * 2)) + this.logPrec - this.ess * this.prec * val * val) + this.priorAlpha * Math.log(this.priorBeta) - Gamma.logOfGamma(this.priorAlpha) + this.priorAlpha * this.logPrec - this.priorBeta * this.prec;
        }
        return 0.0;
    }

    /**
     * Adds the gradient of {@link #getLogPriorTerm()} with respect to {@code mu} and
     * {@code log(prec)} to {@code gradient[start]} and {@code gradient[start + 1]}.
     * Does nothing when {@code ess <= 0}.
     */
    @Override
    public void addGradientOfLogPriorTerm(double[] gradient, int start) {
        if (this.ess > 0.0) {
            double val = this.mu - this.priorMu;
            double gradmu = this.ess * this.prec * val;
            gradient[start] -= gradmu;
            gradient[start + 1] += 0.5 - 0.5 * gradmu * val + this.priorAlpha - this.priorBeta * this.prec;
        }
    }

    /** Returns 2: {@code mu} and {@code log(prec)}. */
    @Override
    public int getNumberOfParameters() {
        return 2;
    }

    /** Reads {@code mu} from {@code pars[start]} and {@code log(prec)} from {@code pars[start + 1]}. */
    @Override
    public void setParameters(double[] pars, int start) {
        this.mu = pars[start];
        this.logPrec = pars[start + 1];
        this.prec = Math.exp(this.logPrec);
    }

    /** Returns {@code {mu, log(prec)}} in the order expected by {@link #setParameters(double[], int)}. */
    @Override
    public double[] getCurrentParameterValues() {
        return new double[]{this.mu, this.logPrec};
    }

    /** Renders a {@code double[2]} statistic via {@link Arrays#toString(double[])}. */
    @Override
    public String toString(Object stat) {
        return Arrays.toString((double[]) stat);
    }

    /** Ignores {@code stats} and {@code weights}; delegates to {@link #initializeFunctionRandomly()}. */
    @Override
    public void initializeFunction(Object[] stats, double[] weights) {
        this.initializeFunctionRandomly();
    }

    /**
     * Draws a random precision and mean. When the prior is active ({@code ess > 0}), the draw
     * follows the Normal-Gamma prior; otherwise the precision comes from {@code Gamma(1, 1)} and
     * {@code mu ~ N(priorMu, 1/prec)}. Prints the drawn values to standard output.
     */
    @Override
    public void initializeFunctionRandomly() {
        double meanPrecision;
        if (this.ess > 0.0) {
            // NOTE(review): assumes nextGamma(shape, scale), hence 1/rate — confirm against jstacs
            this.prec = rand.nextGamma(this.priorAlpha, 1.0 / this.priorBeta);
            meanPrecision = this.ess * this.prec;
        } else {
            this.prec = rand.nextGamma(1.0, 1.0);
            meanPrecision = this.prec;
        }
        // a standard normal is scaled by the standard deviation 1/sqrt(precision), not 1/precision
        this.mu = rand.nextGaussian() / Math.sqrt(meanPrecision) + this.priorMu;
        this.logPrec = Math.log(this.prec);
        System.out.println("mu: " + this.mu + ", prec: " + this.prec);
    }

    @Override
    public String toString() {
        return "N(" + this.mu + ", " + this.prec + ")";
    }
}

