/*
 * Decompiled with CFR 0.152.
 */
package de.jstacs.sequenceScores.statisticalModels.differentiable.continuous;

import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.DataSet;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.io.NonParsableException;
import de.jstacs.io.XMLParser;
import de.jstacs.sequenceScores.statisticalModels.differentiable.AbstractDifferentiableStatisticalModel;
import de.jstacs.sequenceScores.statisticalModels.differentiable.continuous.gamma.GammaPriorFunction;
import de.jstacs.sequenceScores.statisticalModels.differentiable.continuous.gamma.NumericalIntegration;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.IntList;
import de.jtem.numericalMethods.calculus.specialFunctions.Gamma;
import java.text.NumberFormat;
import java.util.Arrays;
import java.util.Random;

/**
 * Differentiable statistical model that scores a window of continuous values with independent
 * Gamma densities, one per position.
 *
 * <p>For position {@code i} with shape {@code alpha[i]} and rate {@code beta[i]}, the log-score of a
 * value {@code x} is {@code alpha*log(beta) - logGamma(alpha) + (alpha-1)*log(x) - beta*x}.
 * The free parameters exposed to optimizers are {@code log(alpha[i])} at indices
 * {@code 0..length-1}, followed by {@code log(beta[i])} at indices {@code length..2*length-1}.
 *
 * <p>The per-position prior term is
 * {@code ess*(-logGamma(alpha) + alpha*log(beta) - mua*beta + alpha*log(mug)) + log(alpha) + log(beta)}.
 * The trailing {@code log(alpha) + log(beta)} is the Jacobian of the log-parameterisation.
 * NOTE(review): {@code mua} / {@code mug} presumably are prior arithmetic / geometric means of the
 * data (judging by name only) — confirm.
 */
public class ExpGammaDiffSM
extends AbstractDifferentiableStatisticalModel {
    /** Source of randomness for {@link #initializeFunctionRandomly(boolean)}. */
    private static final Random r = new Random();

    /**
     * Numerator coefficients (ascending powers of x) of the rational approximation of digamma on
     * [0.5, 3).
     */
    private static final double[] DIGAMMA_MID_NUM = {
        13524.999667726346, 45285.60169954729, 45135.168469736665, 18529.01181858261,
        3329.1525149406934, 240.68032474357202, 5.157789200013909, 0.006228350691898475
    };
    /** Denominator coefficients (ascending powers of x) matching {@link #DIGAMMA_MID_NUM}. */
    private static final double[] DIGAMMA_MID_DEN = {
        6.938911175376345E-7, 19768.574263046736, 41255.16083535383, 29390.287119932684,
        9081.966607485518, 1244.7477785670856, 67.4291295163786, 1.0
    };
    /**
     * Numerator coefficients (ascending powers of w = 1/x^2) of the asymptotic correction term
     * used for x >= 3.
     */
    private static final double[] DIGAMMA_ASYM_NUM = {
        -2.7281757513152966E-15, -0.6481571237661965, -4.486165439180193, -7.016772277667586,
        -2.1294044513101054
    };
    /** Denominator coefficients (ascending powers of w) matching {@link #DIGAMMA_ASYM_NUM}. */
    private static final double[] DIGAMMA_ASYM_DEN = {
        7.777885485229616, 54.61177381032151, 89.29207004818613, 32.270349379114336, 1.0
    };
    /** Positive zero of digamma; the mid-range approximation is {@code (x - root) * P(x)/Q(x)}. */
    private static final double DIGAMMA_ROOT = 1.4616321449683622;

    /** Shape parameter per position. */
    private double[] alphas;
    /** Rate parameter per position. */
    private double[] betas;
    private boolean isInitialized;
    /** If true, {@link #initializeFunction} uses a method-of-moments estimate instead of random values. */
    private boolean plugin;
    /** Equivalent sample size; scales the prior term. */
    private double ess;
    /** Prior hyper-parameter per position; multiplies {@code beta} in the prior. */
    private double[] mua;
    /** Prior hyper-parameter per position; enters the prior as {@code alpha*log(mug)}. */
    private double[] mug;
    /** Cached sum over positions of {@code alpha*log(beta) - logGamma(alpha)}. */
    private double norm;
    /** Constant subtracted from the log-prior term. */
    private double priorNorm;
    /** Cached {@code alpha*log(beta) - alpha*digamma(alpha)} per position (part of the log-alpha gradient). */
    private double[] alphaNorms;

    /**
     * Creates an uninitialised model.
     *
     * @param alphabet the alphabet container of the model
     * @param length the number of positions scored
     * @param ess the equivalent sample size scaling the prior
     * @param mua prior hyper-parameters (one per position is read; defensively copied)
     * @param mug prior hyper-parameters (one per position is read; defensively copied)
     * @param plugin whether {@link #initializeFunction} should use a method-of-moments estimate
     */
    public ExpGammaDiffSM(AlphabetContainer alphabet, int length, double ess, double[] mua, double[] mug, boolean plugin) {
        super(alphabet, length);
        this.alphas = new double[length];
        this.betas = new double[length];
        this.alphaNorms = new double[length];
        this.isInitialized = false;
        this.plugin = plugin;
        this.ess = ess;
        this.mua = mua.clone();
        this.mug = mug.clone();
        this.priorNorm = 1.0;
    }

    /**
     * Re-creates a model from its XML representation (see {@link #toXML()}).
     *
     * @param buf the XML representation
     * @throws NonParsableException if the XML cannot be parsed
     */
    public ExpGammaDiffSM(StringBuffer buf) throws NonParsableException {
        super(buf);
    }

    /**
     * Returns a deep copy of this model; all parameter and cache arrays are copied.
     */
    @Override
    public ExpGammaDiffSM clone() throws CloneNotSupportedException {
        ExpGammaDiffSM clone = (ExpGammaDiffSM) super.clone();
        clone.alphas = this.alphas.clone();
        clone.betas = this.betas.clone();
        clone.mua = this.mua.clone();
        clone.mug = this.mug.clone();
        clone.alphaNorms = this.alphaNorms.clone();
        return clone;
    }

    /** Serialises the full model state, wrapped in a tag named after the class. */
    @Override
    public StringBuffer toXML() {
        StringBuffer buf = new StringBuffer();
        XMLParser.appendObjectWithTags(buf, this.alphabets, "alphabets");
        XMLParser.appendObjectWithTags(buf, this.length, "length");
        XMLParser.appendObjectWithTags(buf, this.alphas, "alphas");
        XMLParser.appendObjectWithTags(buf, this.betas, "betas");
        XMLParser.appendObjectWithTags(buf, this.isInitialized, "isInitialized");
        XMLParser.appendObjectWithTags(buf, this.plugin, "plugin");
        XMLParser.appendObjectWithTags(buf, this.ess, "ess");
        XMLParser.appendObjectWithTags(buf, this.mua, "mua");
        XMLParser.appendObjectWithTags(buf, this.mug, "mug");
        XMLParser.appendObjectWithTags(buf, this.norm, "norm");
        XMLParser.appendObjectWithTags(buf, this.priorNorm, "priorNorm");
        XMLParser.appendObjectWithTags(buf, this.alphaNorms, "alphaNorms");
        XMLParser.addTags(buf, this.getClass().getSimpleName());
        return buf;
    }

    /** Restores the state written by {@link #toXML()}. */
    @Override
    protected void fromXML(StringBuffer xml) throws NonParsableException {
        xml = XMLParser.extractForTag(xml, this.getClass().getSimpleName());
        this.alphabets = XMLParser.extractObjectForTags(xml, "alphabets", AlphabetContainer.class);
        this.length = XMLParser.extractObjectForTags(xml, "length", Integer.TYPE);
        this.alphas = XMLParser.extractObjectForTags(xml, "alphas", double[].class);
        this.betas = XMLParser.extractObjectForTags(xml, "betas", double[].class);
        this.isInitialized = XMLParser.extractObjectForTags(xml, "isInitialized", Boolean.TYPE);
        this.plugin = XMLParser.extractObjectForTags(xml, "plugin", Boolean.TYPE);
        this.ess = XMLParser.extractObjectForTags(xml, "ess", Double.TYPE);
        this.mua = XMLParser.extractObjectForTags(xml, "mua", double[].class);
        this.mug = XMLParser.extractObjectForTags(xml, "mug", double[].class);
        this.norm = XMLParser.extractObjectForTags(xml, "norm", Double.TYPE);
        this.priorNorm = XMLParser.extractObjectForTags(xml, "priorNorm", Double.TYPE);
        this.alphaNorms = XMLParser.extractObjectForTags(xml, "alphaNorms", double[].class);
    }

    /**
     * Adds the gradient of {@link #getLogPriorTerm()} with respect to the log-parameters to
     * {@code grad}.
     *
     * @param grad gradient vector to add into
     * @param start offset of this model's first parameter in {@code grad}
     */
    @Override
    public void addGradientOfLogPriorTerm(double[] grad, int start) throws Exception {
        int len = this.alphas.length;
        for (int i = 0; i < len; i++) {
            double a = this.alphas[i];
            double b = this.betas[i];
            // d/d log(alpha): alpha * (log(beta) - digamma(alpha) + log(mug)) scaled by ess, +1 from the Jacobian
            grad[start + i] += this.ess * (-a * this.digamma(a) + Math.log(b) * a + a * Math.log(this.mug[i])) + 1.0;
            // d/d log(beta): alpha - mua * beta scaled by ess, +1 from the Jacobian
            grad[start + len + i] += this.ess * (a - this.mua[i] * b) + 1.0;
        }
    }

    /** Returns the equivalent sample size passed at construction (or restored from XML). */
    @Override
    public double getESS() {
        return this.ess;
    }

    /**
     * Returns the log-prior of the current parameters, including the Jacobian of the
     * log-parameterisation, minus {@code priorNorm}.
     */
    @Override
    public double getLogPriorTerm() {
        double val = 0.0;
        for (int i = 0; i < this.alphas.length; i++) {
            double a = this.alphas[i];
            double b = this.betas[i];
            double temp = -Gamma.logOfGamma(a) + Math.log(b) * a - this.mua[i] * b + a * Math.log(this.mug[i]);
            temp *= this.ess;
            // Jacobian of the exp() mapping from free parameters to (alpha, beta)
            temp += Math.log(a) + Math.log(b);
            val += temp;
        }
        return val - this.priorNorm;
    }

    /** Returns 0: the per-position Gamma densities are already normalised via {@link #norm}. */
    @Override
    public double getLogNormalizationConstant() {
        return 0.0;
    }

    /**
     * Returns negative infinity (log of 0) for every parameter. NOTE(review): presumably signals
     * that the normalisation constant does not depend on any parameter — confirm against the base class.
     */
    @Override
    public double getLogPartialNormalizationConstant(int parameterIndex) {
        return Double.NEGATIVE_INFINITY;
    }

    /** Returns 1 for every parameter. */
    @Override
    public int getSizeOfEventSpaceForRandomVariablesOfParameter(int index) {
        return 1;
    }

    /**
     * Returns the free parameters: {@code log(alpha)} for all positions followed by
     * {@code log(beta)} for all positions.
     */
    @Override
    public double[] getCurrentParameterValues() throws Exception {
        int len = this.alphas.length;
        double[] pars = new double[len + this.betas.length];
        for (int i = 0; i < len; i++) {
            pars[i] = Math.log(this.alphas[i]);
            pars[i + len] = Math.log(this.betas[i]);
        }
        return pars;
    }

    /** Returns the simple class name. */
    @Override
    public String getInstanceName() {
        return this.getClass().getSimpleName();
    }

    /**
     * Computes {@code priorNorm} as the sum of log-integrals of the per-position prior by
     * numerical integration. NOTE(review): not invoked anywhere in this class, so
     * {@code priorNorm} keeps its constructor value (1.0) or its XML-restored value. Failures are
     * only printed, not propagated.
     */
    private void precomputePriorNormalization() {
        this.priorNorm = 0.0;
        if (this.ess > 0.0) {
            try {
                for (int i = 0; i < this.alphas.length; i++) {
                    double temp = NumericalIntegration.getIntegralByNestedIntervals(
                            new GammaPriorFunction(this.mua[i], this.mug[i], this.ess), 1.0E-10, 0.01, 0.001);
                    this.priorNorm += Math.log(temp);
                }
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
    }

    /**
     * Refreshes the caches {@link #norm} and {@link #alphaNorms} from the current alphas/betas.
     * Must be called after every parameter change.
     */
    private void precomputeNormalization() {
        this.norm = 0.0;
        for (int i = 0; i < this.alphas.length; i++) {
            double a = this.alphas[i];
            double logB = Math.log(this.betas[i]);
            this.norm += a * logB - Gamma.logOfGamma(a);
            this.alphaNorms[i] = a * logB - a * this.digamma(a);
        }
    }

    /**
     * Returns the log-density of the values at positions {@code start..start+length-1} of
     * {@code seq}.
     */
    @Override
    public double getLogScoreFor(Sequence seq, int start) {
        double val = this.norm;
        for (int i = 0; i < this.alphas.length; i++) {
            double cv = seq.continuousVal(i + start);
            val += (this.alphas[i] - 1.0) * Math.log(cv) - this.betas[i] * cv;
        }
        return val;
    }

    /**
     * Returns the same log-score as {@link #getLogScoreFor(Sequence, int)}. For each position it
     * also appends the partial derivatives with respect to {@code log(alpha)} and
     * {@code log(beta)}, together with their parameter indices.
     */
    @Override
    public double getLogScoreAndPartialDerivation(Sequence seq, int start, IntList indices, DoubleList partialDer) {
        double val = this.norm;
        int len = this.alphas.length;
        for (int i = 0; i < len; i++) {
            double cv = seq.continuousVal(i + start);
            double logCv = Math.log(cv);
            double a = this.alphas[i];
            double b = this.betas[i];
            val += (a - 1.0) * logCv - b * cv;
            indices.add(i);
            partialDer.add(this.alphaNorms[i] + a * logCv);
            indices.add(i + len);
            partialDer.add(a - b * cv);
        }
        return val;
    }

    /** Returns {@code 2 * length}: one log-shape and one log-rate per position. */
    @Override
    public int getNumberOfParameters() {
        return this.alphas.length + this.betas.length;
    }

    /**
     * Initialises the parameters. In plugin mode it uses a weighted method-of-moments estimate
     * from {@code data[index]}; otherwise it delegates to {@link #initializeFunctionRandomly(boolean)}.
     *
     * <p>In plugin mode, every value of every sequence in {@code data[index]} (up to the
     * sequence's length) is used. Estimates that are non-finite or non-positive are replaced by 1.0.
     */
    @Override
    public void initializeFunction(int index, boolean freeParams, DataSet[] data, double[][] weights) throws Exception {
        if (!this.plugin) {
            this.initializeFunctionRandomly(freeParams);
            return;
        }
        // alphas / betas temporarily hold weighted first / second raw moments
        Arrays.fill(this.alphas, 0.0);
        Arrays.fill(this.betas, 0.0);
        double totalWeight = 0.0;
        for (int i = 0; i < data[index].getNumberOfElements(); i++) {
            Sequence seq = data[index].getElementAt(i);
            double w = weights == null || weights[index] == null ? 1.0 : weights[index][i];
            for (int j = 0; j < seq.getLength(); j++) {
                double cv = seq.continuousVal(j);
                this.alphas[j] += w * cv;
                this.betas[j] += w * cv * cv;
            }
            totalWeight += w;
        }
        for (int i = 0; i < this.alphas.length; i++) {
            double mean = this.alphas[i] / totalWeight;
            double secondMoment = this.betas[i] / totalWeight;
            // scale = variance / mean
            double theta = secondMoment / mean - mean;
            // shape = mean / scale; the former (var - mean)/(theta*(theta-1)) is equal
            // algebraically but gives NaN (0/0) when theta == 1
            this.alphas[i] = positiveFiniteOrOne(mean / theta);
            this.betas[i] = positiveFiniteOrOne(1.0 / theta);
        }
        this.precomputeNormalization();
        this.isInitialized = true;
    }

    /**
     * Returns {@code v} if it is a finite positive number, otherwise 1.0. Gamma parameters must be
     * strictly positive and finite for the log-parameterisation.
     */
    private static double positiveFiniteOrOne(double v) {
        return v > 0.0 && !Double.isInfinite(v) ? v : 1.0;
    }

    /**
     * Draws shapes uniformly from (0, 10) and rates uniformly from (0, 1). {@code Double.MIN_VALUE}
     * is added so that neither parameter can be exactly 0 (whose log is -infinity).
     */
    @Override
    public void initializeFunctionRandomly(boolean freeParams) throws Exception {
        for (int i = 0; i < this.alphas.length; i++) {
            this.alphas[i] = r.nextDouble() * 10.0 + Double.MIN_VALUE;
            this.betas[i] = r.nextDouble() + Double.MIN_VALUE;
        }
        this.precomputeNormalization();
        this.isInitialized = true;
    }

    /** Returns whether parameters have been initialised. */
    @Override
    public boolean isInitialized() {
        return this.isInitialized;
    }

    /**
     * Sets alphas/betas from free log-parameters laid out as in
     * {@link #getCurrentParameterValues()}, starting at {@code params[start]}.
     */
    @Override
    public void setParameters(double[] params, int start) {
        int len = this.alphas.length;
        for (int i = 0; i < len; i++) {
            this.alphas[i] = Math.exp(params[start + i]);
            this.betas[i] = Math.exp(params[start + len + i]);
        }
        this.precomputeNormalization();
    }

    /** Returns the class name followed by the alpha and beta arrays; {@code nf} is not used. */
    @Override
    public String toString(NumberFormat nf) {
        return this.getClass().getSimpleName()
                + "\nalphas: " + Arrays.toString(this.alphas) + "\n"
                + "betas: " + Arrays.toString(this.betas) + "\n";
    }

    /**
     * Digamma function psi(x) by piecewise rational approximation:
     * <ul>
     * <li>an asymptotic expansion in 1/x^2 for x >= 3;</li>
     * <li>a rational function centred on the positive root of psi for 0.5 <= x < 3;</li>
     * <li>the reflection formula psi(x) = psi(1-x) - pi*cot(pi*x) for x < 0.5.</li>
     * </ul>
     *
     * @throws IllegalArgumentException if {@code x} is NaN or infinite (the message includes the
     *     model state to help diagnose how the invalid parameter arose)
     */
    private double digamma(double x) {
        if (Double.isNaN(x) || Double.isInfinite(x)) {
            throw new IllegalArgumentException("argument of digamma is " + x
                    + " (alphas=" + Arrays.toString(this.alphas)
                    + ", betas=" + Arrays.toString(this.betas)
                    + ", mua=" + Arrays.toString(this.mua)
                    + ", mug=" + Arrays.toString(this.mug)
                    + ", norm=" + this.norm
                    + ", priorNorm=" + this.priorNorm
                    + ", alphaNorms=" + Arrays.toString(this.alphaNorms) + ")");
        }
        if (x >= 3.0) {
            double w = 1.0 / (x * x);
            double num = 0.0;
            double den = 0.0;
            // Horner evaluation, highest power first
            for (int j = DIGAMMA_ASYM_NUM.length - 1; j >= 0; j--) {
                num = num * w + DIGAMMA_ASYM_NUM[j];
                den = den * w + DIGAMMA_ASYM_DEN[j];
            }
            return Math.log(x) - 0.5 / x + num / den;
        }
        if (x >= 0.5) {
            double num = 0.0;
            double den = 0.0;
            for (int j = DIGAMMA_MID_NUM.length - 1; j >= 0; j--) {
                num = x * num + DIGAMMA_MID_NUM[j];
                den = x * den + DIGAMMA_MID_DEN[j];
            }
            return (x - DIGAMMA_ROOT) * (num / den);
        }
        // Reflection: 1 - x >= 0.5, so the recursion terminates after one step.
        // tan has period pi, so the fractional part f gives the same cotangent with a reduced argument.
        double f = 1.0 - x - Math.floor(1.0 - x);
        return this.digamma(1.0 - x) + Math.PI / Math.tan(Math.PI * f);
    }
}

