package de.jstacs.sequenceScores.statisticalModels.differentiable.continuous;

import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.DataSet;
import de.jstacs.data.alphabets.ContinuousAlphabet;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.io.NonParsableException;
import de.jstacs.io.XMLParser;
import de.jstacs.sequenceScores.statisticalModels.differentiable.AbstractDifferentiableStatisticalModel;
import de.jstacs.sequenceScores.statisticalModels.differentiable.continuous.gamma.GammaPriorFunction;
import de.jstacs.sequenceScores.statisticalModels.differentiable.continuous.gamma.NumericalIntegration;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.IntList;
import de.jtem.numericalMethods.calculus.specialFunctions.Gamma;
import java.text.NumberFormat;
import java.util.Arrays;
import java.util.Random;
import projects.dispom.DispomParameterSet;

/* loaded from: input_file:de/jstacs/sequenceScores/statisticalModels/differentiable/continuous/ExpGammaDiffSM.class */
public class ExpGammaDiffSM extends AbstractDifferentiableStatisticalModel {
    // Shared random source used by initializeFunctionRandomly.
    private static Random r = new Random();
    // Per-position alpha parameters; exposed to callers as log(alpha) (see getCurrentParameterValues/setParameters).
    private double[] alphas;
    // Per-position beta parameters; exposed to callers as log(beta).
    private double[] betas;
    // Set to true once initializeFunction or initializeFunctionRandomly has run.
    private boolean isInitialized;
    // If true, initializeFunction estimates the parameters from data; otherwise it initializes randomly.
    private boolean plugin;
    // Equivalent sample size; scales the prior terms in getLogPriorTerm and addGradientOfLogPriorTerm.
    private double ess;
    // Per-position prior hyper-parameters; enter the log prior as -mua[i] * beta[i].
    private double[] mua;
    // Per-position prior hyper-parameters; enter the log prior as alpha[i] * log(mug[i]).
    private double[] mug;
    // Cached sum over positions of alpha*log(beta) - logGamma(alpha), refreshed by precomputeNormalization.
    private double norm;
    // Value subtracted from the log prior in getLogPriorTerm.
    private double priorNorm;
    // Cached per-position alpha*log(beta) - alpha*digamma(alpha), refreshed by precomputeNormalization.
    private double[] alphaNorms;

    /**
     * Creates an uninitialized model over a continuous alphabet.
     *
     * @param i number of positions; one alpha/beta pair is allocated per position
     * @param d equivalent sample size, stored as {@code ess}
     * @param dArr prior hyper-parameters stored (as a copy) in {@code mua}
     * @param dArr2 prior hyper-parameters stored (as a copy) in {@code mug}
     * @param z stored as {@code plugin}; when {@code true}, {@code initializeFunction}
     *        estimates the parameters from data instead of drawing them randomly
     */
    public ExpGammaDiffSM(int i, double d, double[] dArr, double[] dArr2, boolean z) {
        super(new AlphabetContainer(new ContinuousAlphabet()), i);

        // Parameter storage and the per-position cache, one slot per position.
        alphas = new double[i];
        betas = new double[i];
        alphaNorms = new double[i];

        // Configuration; the hyper-parameter arrays are copied so later changes
        // to the caller's arrays do not affect this instance.
        isInitialized = false;
        plugin = z;
        ess = d;
        mua = dArr.clone();
        mug = dArr2.clone();
        priorNorm = 1.0d;
    }

    /**
     * Creates an instance from an XML representation by delegating to the superclass
     * constructor.
     *
     * @param stringBuffer the XML representation (presumably as produced by {@code toXML()})
     * @throws NonParsableException propagated from the superclass constructor
     */
    public ExpGammaDiffSM(StringBuffer stringBuffer) throws NonParsableException {
        super(stringBuffer);
    }

    /**
     * Returns a copy of this model whose parameter, hyper-parameter and cache arrays
     * are independent of the arrays held by this instance.
     *
     * @return the copy
     * @throws CloneNotSupportedException propagated from the superclass clone
     */
    @Override // de.jstacs.sequenceScores.statisticalModels.differentiable.AbstractDifferentiableStatisticalModel, de.jstacs.sequenceScores.differentiable.AbstractDifferentiableSequenceScore, de.jstacs.sequenceScores.differentiable.DifferentiableSequenceScore, de.jstacs.sequenceScores.SequenceScore
    /* renamed from: clone */
    public ExpGammaDiffSM mo114clone() throws CloneNotSupportedException {
        final ExpGammaDiffSM copy = (ExpGammaDiffSM) super.mo114clone();
        // Replace every array field of the copy with its own array.
        copy.alphas = alphas.clone();
        copy.betas = betas.clone();
        copy.mua = mua.clone();
        copy.mug = mug.clone();
        copy.alphaNorms = alphaNorms.clone();
        return copy;
    }

    /**
     * Serializes all fields of this model into XML, wrapped in a tag named after the
     * simple class name. The tag order mirrors the read order in {@code fromXML}.
     *
     * @return the XML representation
     */
    @Override // de.jstacs.Storable
    public StringBuffer toXML() {
        final StringBuffer xml = new StringBuffer();

        // Inherited state.
        XMLParser.appendObjectWithTags(xml, this.alphabets, "alphabets");
        XMLParser.appendObjectWithTags(xml, Integer.valueOf(this.length), DispomParameterSet.LENGTH);

        // Parameters and flags.
        XMLParser.appendObjectWithTags(xml, alphas, "alphas");
        XMLParser.appendObjectWithTags(xml, betas, "betas");
        XMLParser.appendObjectWithTags(xml, Boolean.valueOf(isInitialized), "isInitialized");
        XMLParser.appendObjectWithTags(xml, Boolean.valueOf(plugin), "plugin");

        // Prior configuration.
        XMLParser.appendObjectWithTags(xml, Double.valueOf(ess), "ess");
        XMLParser.appendObjectWithTags(xml, mua, "mua");
        XMLParser.appendObjectWithTags(xml, mug, "mug");

        // Cached values.
        XMLParser.appendObjectWithTags(xml, Double.valueOf(norm), "norm");
        XMLParser.appendObjectWithTags(xml, Double.valueOf(priorNorm), "priorNorm");
        XMLParser.appendObjectWithTags(xml, alphaNorms, "alphaNorms");

        XMLParser.addTags(xml, getClass().getSimpleName());
        return xml;
    }

    /**
     * Restores all fields of this model from XML. The content is expected inside a tag
     * named after the simple class name, with the same tags that {@code toXML} writes.
     *
     * @param stringBuffer the XML representation to read from
     * @throws NonParsableException if the XML parser cannot extract the expected content
     */
    @Override // de.jstacs.sequenceScores.differentiable.AbstractDifferentiableSequenceScore
    protected void fromXML(StringBuffer stringBuffer) throws NonParsableException {
        final StringBuffer content = XMLParser.extractForTag(stringBuffer, getClass().getSimpleName());

        // Inherited state.
        this.alphabets = (AlphabetContainer) XMLParser.extractObjectForTags(content, "alphabets", AlphabetContainer.class);
        this.length = (Integer) XMLParser.extractObjectForTags(content, DispomParameterSet.LENGTH, int.class);

        // Parameters and flags.
        alphas = (double[]) XMLParser.extractObjectForTags(content, "alphas", double[].class);
        betas = (double[]) XMLParser.extractObjectForTags(content, "betas", double[].class);
        isInitialized = (Boolean) XMLParser.extractObjectForTags(content, "isInitialized", boolean.class);
        plugin = (Boolean) XMLParser.extractObjectForTags(content, "plugin", boolean.class);

        // Prior configuration.
        ess = (Double) XMLParser.extractObjectForTags(content, "ess", double.class);
        mua = (double[]) XMLParser.extractObjectForTags(content, "mua", double[].class);
        mug = (double[]) XMLParser.extractObjectForTags(content, "mug", double[].class);

        // Cached values are read back rather than recomputed.
        norm = (Double) XMLParser.extractObjectForTags(content, "norm", double.class);
        priorNorm = (Double) XMLParser.extractObjectForTags(content, "priorNorm", double.class);
        alphaNorms = (double[]) XMLParser.extractObjectForTags(content, "alphaNorms", double[].class);
    }

    /**
     * Adds the gradient of the log prior, taken with respect to the log-space
     * parameters (log alpha, log beta), to the given gradient array.
     *
     * <p>The alpha entries are written to {@code dArr[i .. i + n - 1]} and the beta
     * entries to {@code dArr[i + n .. i + 2n - 1]}, where {@code n} is the number of
     * positions.
     *
     * @param dArr gradient array that is updated in place
     * @param i offset of this model's first parameter in {@code dArr}
     * @throws Exception declared by the overridden method
     */
    @Override // de.jstacs.sequenceScores.statisticalModels.differentiable.DifferentiableStatisticalModel
    public void addGradientOfLogPriorTerm(double[] dArr, int i) throws Exception {
        final int positions = this.alphas.length;
        for (int pos = 0; pos < positions; pos++) {
            try {
                final double alpha = this.alphas[pos];
                final double beta = this.betas[pos];

                // d/d(log alpha): ess * alpha * (-digamma(alpha) + log(beta) + log(mug)) + 1
                final int alphaIndex = i + pos;
                final double digammaTerm = (-alpha) * digamma(alpha);
                final double logBetaTerm = Math.log(beta) * alpha;
                final double logMugTerm = alpha * Math.log(this.mug[pos]);
                final double alphaGradient = this.ess * (digammaTerm + logBetaTerm + logMugTerm);
                dArr[alphaIndex] = dArr[alphaIndex] + alphaGradient + 1.0d;

                // d/d(log beta): ess * (alpha - mua * beta) + 1
                final int betaIndex = i + positions + pos;
                final double betaGradient = this.ess * (alpha - (this.mua[pos] * beta));
                dArr[betaIndex] = dArr[betaIndex] + betaGradient + 1.0d;
            } catch (StackOverflowError e) {
                // Report the offending alpha before propagating the error.
                System.out.println(this.alphas[pos]);
                throw e;
            }
        }
    }

    /**
     * Returns the equivalent sample size given at construction (or read from XML).
     *
     * @return the value of {@code ess}
     */
    @Override // de.jstacs.sequenceScores.statisticalModels.differentiable.DifferentiableStatisticalModel
    public double getESS() {
        return this.ess;
    }

    /**
     * Computes the log prior of the current parameters.
     *
     * <p>For every position the term
     * {@code ess * (-logGamma(alpha) + alpha*log(beta) - mua*beta + alpha*log(mug))
     * + log(alpha) + log(beta)} is accumulated; {@code priorNorm} is subtracted from
     * the total.
     *
     * @return the log prior term
     */
    @Override // de.jstacs.sequenceScores.statisticalModels.differentiable.DifferentiableStatisticalModel, de.jstacs.sequenceScores.statisticalModels.StatisticalModel
    public double getLogPriorTerm() {
        double total = 0.0d;
        for (int pos = 0; pos < this.alphas.length; pos++) {
            final double alpha = this.alphas[pos];
            final double beta = this.betas[pos];
            final double logBeta = Math.log(beta);

            // Part of the prior that is scaled by the equivalent sample size.
            final double scaled = ((((-Gamma.logOfGamma(alpha)) + (logBeta * alpha)) - (this.mua[pos] * beta)) + (alpha * Math.log(this.mug[pos]))) * this.ess;

            // The additional log(alpha) + log(beta) terms are not scaled by ess.
            total += scaled + Math.log(alpha) + logBeta;
        }
        return total - this.priorNorm;
    }

    /**
     * Returns the log normalization constant, which is always {@code 0.0} for this model.
     * NOTE(review): presumably because the score from {@code getLogScoreFor} already
     * contains the cached normalization term {@code norm} - confirm.
     *
     * @return {@code 0.0}
     */
    @Override // de.jstacs.sequenceScores.statisticalModels.differentiable.DifferentiableStatisticalModel
    public double getLogNormalizationConstant() {
        return 0.0d;
    }

    /**
     * Returns {@link Double#NEGATIVE_INFINITY} for every parameter index.
     *
     * @param i parameter index (ignored)
     * @return {@code Double.NEGATIVE_INFINITY}
     */
    @Override // de.jstacs.sequenceScores.statisticalModels.differentiable.DifferentiableStatisticalModel
    public double getLogPartialNormalizationConstant(int i) {
        return Double.NEGATIVE_INFINITY;
    }

    /**
     * Returns {@code 1} for every parameter index.
     *
     * @param i parameter index (ignored)
     * @return {@code 1}
     */
    @Override // de.jstacs.sequenceScores.statisticalModels.differentiable.DifferentiableStatisticalModel
    public int getSizeOfEventSpaceForRandomVariablesOfParameter(int i) {
        return 1;
    }

    /**
     * Returns the current parameters in log space: first {@code log(alpha)} for every
     * position, then {@code log(beta)} for every position.
     *
     * @return a new array of length {@code alphas.length + betas.length}
     * @throws Exception declared by the overridden method
     */
    @Override // de.jstacs.sequenceScores.differentiable.DifferentiableSequenceScore
    public double[] getCurrentParameterValues() throws Exception {
        final int positions = this.alphas.length;
        final double[] logParameters = new double[positions + this.betas.length];
        for (int pos = 0; pos < positions; pos++) {
            logParameters[pos] = Math.log(this.alphas[pos]);
            logParameters[positions + pos] = Math.log(this.betas[pos]);
        }
        return logParameters;
    }

    /**
     * Returns the simple class name of this instance as its name.
     *
     * @return the simple name of the runtime class
     */
    @Override // de.jstacs.sequenceScores.SequenceScore
    public String getInstanceName() {
        return getClass().getSimpleName();
    }

    /**
     * Recomputes {@code priorNorm} as the sum over positions of the logarithm of a
     * numerically integrated {@code GammaPriorFunction}. For a non-positive (or NaN)
     * {@code ess} the value is left at {@code 0.0}.
     *
     * <p>If the integration throws, the stack trace is printed and the method stops;
     * {@code priorNorm} then holds the sum over the positions processed so far.
     * This method has no caller within this class as shown.
     */
    private void precomputePriorNormalization() {
        this.priorNorm = 0.0d;
        if (!(this.ess > 0.0d)) {
            return;
        }
        try {
            for (int pos = 0; pos < this.alphas.length; pos++) {
                final GammaPriorFunction prior = new GammaPriorFunction(this.mua[pos], this.mug[pos], this.ess);
                this.priorNorm += Math.log(NumericalIntegration.getIntegralByNestedIntervals(prior, 1.0E-10d, 0.01d, 0.001d));
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Refreshes the cached values that depend on the current parameters:
     * {@code norm} becomes the sum over positions of
     * {@code alpha*log(beta) - logGamma(alpha)}, and {@code alphaNorms[i]} becomes
     * {@code alpha*log(beta) - alpha*digamma(alpha)}.
     */
    private void precomputeNormalization() {
        this.norm = 0.0d;
        for (int pos = 0; pos < this.alphas.length; pos++) {
            final double alpha = this.alphas[pos];
            final double alphaLogBeta = alpha * Math.log(this.betas[pos]);
            this.norm += alphaLogBeta - Gamma.logOfGamma(alpha);
            this.alphaNorms[pos] = alphaLogBeta - (alpha * digamma(alpha));
        }
    }

    /**
     * Computes the log score of the values of {@code sequence} starting at offset
     * {@code i}: the cached {@code norm} plus, per position,
     * {@code (alpha - 1) * log(x) - beta * x}, where {@code x} is the observed value
     * shifted by {@code 1e-10}.
     *
     * <p>Whenever the running score becomes infinite or NaN, the score, both
     * parameter arrays and the current value are printed to standard output.
     *
     * @param sequence the sequence to score
     * @param i start position within {@code sequence}
     * @return the log score
     */
    @Override // de.jstacs.sequenceScores.SequenceScore
    public double getLogScoreFor(Sequence sequence, int i) {
        double score = this.norm;
        final int positions = this.alphas.length;
        for (int pos = 0; pos < positions; pos++) {
            // The small offset keeps log(x) finite for an observed value of exactly 0.
            final double x = sequence.continuousVal(i + pos) + 1.0E-10d;
            score += ((this.alphas[pos] - 1.0d) * Math.log(x)) - (this.betas[pos] * x);
            if (Double.isNaN(score) || Double.isInfinite(score)) {
                System.out.println(score + " " + Arrays.toString(this.alphas) + " " + Arrays.toString(this.betas) + " " + x);
            }
        }
        return score;
    }

    /**
     * Computes the same log score as {@code getLogScoreFor} and additionally appends
     * the partial derivatives with respect to the log-space parameters.
     *
     * <p>For each position two (index, value) pairs are appended, in this order:
     * index {@code pos} with {@code alphaNorms[pos] + alpha * log(x)}, then index
     * {@code pos + alphas.length} with {@code alpha - beta * x}.
     *
     * @param sequence the sequence to score
     * @param i start position within {@code sequence}
     * @param intList receives the parameter indices
     * @param doubleList receives the corresponding partial derivatives
     * @return the log score
     */
    @Override // de.jstacs.sequenceScores.differentiable.DifferentiableSequenceScore
    public double getLogScoreAndPartialDerivation(Sequence sequence, int i, IntList intList, DoubleList doubleList) {
        double score = this.norm;
        final int positions = this.alphas.length;
        for (int pos = 0; pos < positions; pos++) {
            final double alpha = this.alphas[pos];
            final double beta = this.betas[pos];
            // The small offset keeps log(x) finite for an observed value of exactly 0.
            final double x = sequence.continuousVal(i + pos) + 1.0E-10d;
            final double logX = Math.log(x);
            final double betaX = beta * x;

            score += ((alpha - 1.0d) * logX) - betaX;

            // Derivative with respect to log(alpha).
            intList.add(pos);
            doubleList.add(this.alphaNorms[pos] + (alpha * logX));

            // Derivative with respect to log(beta).
            intList.add(pos + positions);
            doubleList.add(alpha - betaX);
        }
        return score;
    }

    /**
     * Returns the number of parameters: one alpha and one beta per position.
     *
     * @return {@code alphas.length + betas.length}
     */
    @Override // de.jstacs.sequenceScores.differentiable.DifferentiableSequenceScore
    public int getNumberOfParameters() {
        final int alphaCount = this.alphas.length;
        final int betaCount = this.betas.length;
        return alphaCount + betaCount;
    }

    /**
     * Initializes the parameters, either randomly or from the weighted moments of the
     * data.
     *
     * <p>If {@code plugin} is {@code false} this delegates to
     * {@link #initializeFunctionRandomly(boolean)}. Otherwise the weighted mean and
     * weighted second moment of every position are computed from
     * {@code dataSetArr[i]}, and alpha and beta are derived from them
     * ({@code beta = mean / variance}, alpha by the matching moment formula).
     *
     * <p>Both parameters must be strictly positive because they are handled in log
     * space elsewhere in this class. Any estimate that is NaN, infinite or not
     * positive (e.g. for constant data, a zero mean, or a total weight of zero) is
     * therefore replaced by {@code 1.0}.
     *
     * <p>Sequences are read from position 0 and index the parameter arrays directly,
     * so they must not be longer than the model.
     *
     * @param i index of the data set (and of the weight array) to use
     * @param z passed on to {@code initializeFunctionRandomly} when not in plug-in mode
     * @param dataSetArr the data sets
     * @param dArr per-data-set sequence weights; if {@code null} or if {@code dArr[i]}
     *        is {@code null}, every sequence has weight 1
     * @throws Exception if the data cannot be accessed or random initialization fails
     */
    @Override // de.jstacs.sequenceScores.differentiable.DifferentiableSequenceScore
    public void initializeFunction(int i, boolean z, DataSet[] dataSetArr, double[][] dArr) throws Exception {
        if (!this.plugin) {
            initializeFunctionRandomly(z);
            return;
        }

        // Accumulate weighted sums of x (in alphas) and x^2 (in betas) per position.
        Arrays.fill(this.alphas, 0.0d);
        Arrays.fill(this.betas, 0.0d);
        double totalWeight = 0.0d;
        for (int n = 0; n < dataSetArr[i].getNumberOfElements(); n++) {
            Sequence seq = dataSetArr[i].getElementAt(n);
            double weight = (dArr == null || dArr[i] == null) ? 1.0d : dArr[i][n];
            for (int pos = 0; pos < seq.getLength(); pos++) {
                this.alphas[pos] = this.alphas[pos] + (weight * seq.continuousVal(pos));
                this.betas[pos] = this.betas[pos] + (weight * seq.continuousVal(pos) * seq.continuousVal(pos));
            }
            totalWeight += weight;
        }

        // Turn the sums into moments and the moments into parameter estimates.
        for (int pos = 0; pos < this.alphas.length; pos++) {
            this.alphas[pos] = this.alphas[pos] / totalWeight;
            this.betas[pos] = this.betas[pos] / totalWeight;
            // secondMoment / mean - mean == variance / mean, i.e. the reciprocal of beta.
            double varianceOverMean = (this.betas[pos] / this.alphas[pos]) - this.alphas[pos];
            double alphaEstimate = ((this.betas[pos] - (this.alphas[pos] * this.alphas[pos])) - this.alphas[pos]) / (varianceOverMean * (varianceOverMean - 1.0d));
            double betaEstimate = 1.0d / varianceOverMean;
            this.alphas[pos] = positiveFiniteOrOne(alphaEstimate);
            this.betas[pos] = positiveFiniteOrOne(betaEstimate);
        }
        precomputeNormalization();
        this.isInitialized = true;
    }

    /** Returns {@code value} if it is a finite, strictly positive number, and {@code 1.0} otherwise. */
    private static double positiveFiniteOrOne(double value) {
        // "value > 0" is false for NaN, so NaN falls through to the default as well.
        return (value > 0.0d && !Double.isInfinite(value)) ? value : 1.0d;
    }

    /**
     * Initializes all parameters randomly: every alpha is drawn uniformly from
     * [0, 10) and every beta uniformly from [0, 1), each shifted by
     * {@link Double#MIN_VALUE} so that a draw of exactly 0 cannot produce a zero
     * parameter (whose logarithm would be negative infinity). The shift does not
     * change any non-zero draw. Afterwards the cached normalization values are
     * refreshed and the model is marked as initialized.
     *
     * @param z not used by this implementation
     * @throws Exception declared by the overridden method
     */
    @Override // de.jstacs.sequenceScores.differentiable.DifferentiableSequenceScore
    public void initializeFunctionRandomly(boolean z) throws Exception {
        for (int pos = 0; pos < this.alphas.length; pos++) {
            this.alphas[pos] = (r.nextDouble() * 10.0d) + Double.MIN_VALUE;
            this.betas[pos] = r.nextDouble() + Double.MIN_VALUE;
        }
        precomputeNormalization();
        this.isInitialized = true;
    }

    /**
     * Reports whether the parameters have been initialized, i.e. whether
     * {@code initializeFunction} or {@code initializeFunctionRandomly} has completed
     * (or the flag was restored as {@code true} from XML).
     *
     * @return the value of the {@code isInitialized} flag
     */
    @Override // de.jstacs.sequenceScores.SequenceScore
    public boolean isInitialized() {
        return this.isInitialized;
    }

    /**
     * Sets the parameters from log-space values and refreshes the cached
     * normalization values.
     *
     * <p>{@code dArr[i .. i + n - 1]} holds {@code log(alpha)} and
     * {@code dArr[i + n .. i + 2n - 1]} holds {@code log(beta)}, where {@code n} is
     * the number of positions.
     *
     * @param dArr array containing the log-space parameters
     * @param i offset of this model's first parameter in {@code dArr}
     */
    @Override // de.jstacs.sequenceScores.differentiable.DifferentiableSequenceScore
    public void setParameters(double[] dArr, int i) {
        final int positions = this.alphas.length;
        for (int pos = 0; pos < positions; pos++) {
            this.alphas[pos] = Math.exp(dArr[i + pos]);
            this.betas[pos] = Math.exp(dArr[i + positions + pos]);
        }
        precomputeNormalization();
    }

    /**
     * Returns a three-line description: the simple class name, the alphas and the betas.
     *
     * @param numberFormat not used by this implementation; values are printed with
     *        {@link Arrays#toString(double[])}
     * @return the textual description
     */
    @Override // de.jstacs.sequenceScores.SequenceScore
    public String toString(NumberFormat numberFormat) {
        final StringBuilder text = new StringBuilder(getClass().getSimpleName());
        text.append("\nalphas: ").append(Arrays.toString(this.alphas)).append("\n");
        text.append("betas: ").append(Arrays.toString(this.betas)).append("\n");
        return text.toString();
    }

    /**
     * Evaluates the digamma function (the derivative of log-gamma) at {@code d}.
     *
     * <p>Three regimes are used:
     * <ul>
     *   <li>{@code d >= 3}: {@code log(d) - 0.5/d} plus a ratio of two polynomials in
     *       {@code 1/d^2};</li>
     *   <li>{@code 0.5 <= d < 3}: {@code (d - c)} times a ratio of two polynomials in
     *       {@code d}, with {@code c = 1.4616321449683622}, so the result is zero at
     *       {@code d = c};</li>
     *   <li>{@code d < 0.5}: the reflection
     *       {@code digamma(d) = digamma(1 - d) + pi / tan(pi * (1 - d))}; since
     *       {@code 1 - d > 0.5} the recursion is one level deep.</li>
     * </ul>
     * NOTE(review): the coefficient tables look like a published rational/Chebyshev
     * approximation of psi - confirm the source before changing any constant.
     *
     * @param d the argument
     * @return the approximated digamma value
     * @throws RuntimeException if {@code d} is NaN or infinite; the model state is
     *         printed to standard output first
     */
    private double digamma(double d) {
        if (Double.isNaN(d) || Double.isInfinite(d)) {
            // Dump the full model state to help locate the source of the bad argument.
            System.out.flush();
            System.out.println(Arrays.toString(this.alphas));
            System.out.println(Arrays.toString(this.betas));
            System.out.println(Arrays.toString(this.mua));
            System.out.println(Arrays.toString(this.mug));
            System.out.println(String.valueOf(this.norm) + ", " + this.priorNorm + "," + Arrays.toString(this.alphaNorms));
            System.out.flush();
            throw new RuntimeException("argument of digamma is " + d);
        }

        // Row 0: numerator coefficients, row 1: denominator coefficients (ascending powers).
        final double[][] midRangeCoefficients = {
            {13524.999667726346d, 45285.60169954729d, 45135.168469736665d, 18529.01181858261d, 3329.1525149406934d, 240.68032474357202d, 5.157789200013909d, 0.006228350691898475d},
            {6.938911175376345E-7d, 19768.574263046736d, 41255.16083535383d, 29390.287119932684d, 9081.966607485518d, 1244.7477785670856d, 67.4291295163786d, 1.0d}
        };
        final double[][] largeArgumentCoefficients = {
            {-2.7281757513152966E-15d, -0.6481571237661965d, -4.486165439180193d, -7.016772277667586d, -2.1294044513101054d},
            {7.777885485229616d, 54.61177381032151d, 89.29207004818613d, 32.270349379114336d, 1.0d}
        };

        double numerator = 0.0d;
        double denominator = 0.0d;
        if (d >= 3.0d) {
            // Horner evaluation in 1/d^2.
            final double inverseSquare = 1.0d / (d * d);
            for (int k = 4; k >= 0; k--) {
                numerator = (numerator * inverseSquare) + largeArgumentCoefficients[0][k];
                denominator = (denominator * inverseSquare) + largeArgumentCoefficients[1][k];
            }
            return (Math.log(d) - (0.5d / d)) + (numerator / denominator);
        }
        if (d >= 0.5d) {
            // Horner evaluation in d.
            for (int k = 7; k >= 0; k--) {
                numerator = (d * numerator) + midRangeCoefficients[0][k];
                denominator = (d * denominator) + midRangeCoefficients[1][k];
            }
            return (d - 1.4616321449683622d) * (numerator / denominator);
        }
        // Reflection; subtracting floor(1 - d) reduces the tangent argument by whole periods.
        final double reflected = 1.0d - d;
        return digamma(reflected) + (Math.PI / Math.tan(Math.PI * (reflected - Math.floor(reflected))));
    }
}
