package de.jstacs.classifiers.differentiableSequenceScoreBased.logPrior;

import de.jstacs.algorithms.optimization.DimensionException;
import de.jstacs.algorithms.optimization.EvaluationException;
import de.jstacs.sequenceScores.differentiable.DifferentiableSequenceScore;
import de.jstacs.sequenceScores.statisticalModels.differentiable.DifferentiableStatisticalModel;
import de.jstacs.utils.Normalisation;
import de.jtem.numericalMethods.calculus.specialFunctions.Gamma;

/* loaded from: input_file:de/jstacs/classifiers/differentiableSequenceScoreBased/logPrior/CompositeLogPrior.class */
/**
 * A log prior that combines several {@link DifferentiableStatisticalModel}s, each weighted by a class
 * parameter.
 *
 * <p>With {@code b_i} the class parameter, {@code ess_i} the equivalent sample size,
 * {@code Z_i} the normalisation constant and {@code P_i} the log prior term of model {@code i},
 * {@link #evaluateFunction(double[])} returns
 * <pre>
 *   logGamma(sum_i ess_i) - sum_i logGamma(ess_i)
 *     - (sum_i ess_i) * log(sum_i exp(b_i + log Z_i))
 *     + sum_i (b_i * ess_i + P_i)
 * </pre>
 *
 * <p>The parameter vector is laid out as follows. First come the class parameters: one per model,
 * or one fewer if {@code freeParameters} was set. In that case the last model's class parameter is
 * fixed to 0 and is not part of the vector. After that come the parameters of each model, in
 * model order.
 *
 * <p>{@link #set(boolean, DifferentiableSequenceScore...)} must be called before any other method,
 * because the other methods dereference the state it initialises.
 */
public class CompositeLogPrior extends LogPrior {

    /** The models whose normalisation constants and log prior terms are combined. */
    private DifferentiableStatisticalModel[] function;

    /** Sum of the equivalent sample sizes of all models. */
    private double fullEss;

    /** {@code logGamma(fullEss) - sum_i logGamma(ess[i])}, cached by {@code set}. */
    private double logGammaSum;

    /** Equivalent sample size of each model. */
    private double[] ess;

    /** Scratch buffer for {@code addGradientFor}: class parameter of each model (0 for a fixed one). */
    private double[] classPars;

    /**
     * Scratch buffer for {@code addGradientFor}. It first holds {@code classPars[i] + log Z_i}.
     * After the in-place normalisation, it holds the normalised weights.
     */
    private double[] logPartNorm;

    /** If {@code true}, the last class parameter is fixed to 0 and is not part of the parameter vector. */
    private boolean freeParameters;

    /** Creates a new, unconfigured prior. */
    public CompositeLogPrior() {
    }

    /**
     * Creates a prior from its XML representation. {@link #toXML()} stores no state, so the
     * argument is ignored.
     *
     * @param xml the XML representation (ignored)
     */
    public CompositeLogPrior(StringBuffer xml) {
        this();
    }

    /**
     * Configures the prior for the given models.
     *
     * <p>All validation happens before any field is assigned. If this method throws, a previous
     * configuration stays intact.
     *
     * @param freeParameters if {@code true}, the last class parameter is fixed to 0
     * @param funs the models; each must be a {@link DifferentiableStatisticalModel}
     * @throws Exception if an element is not a {@link DifferentiableStatisticalModel}
     * @throws IllegalArgumentException if a model's equivalent sample size is not positive
     */
    @Override
    public void set(boolean freeParameters, DifferentiableSequenceScore... funs) throws Exception {
        int n = funs.length;
        DifferentiableStatisticalModel[] models = new DifferentiableStatisticalModel[n];
        double[] essValues = new double[n];
        double essSum = 0.0d;
        double logGammaTerm = 0.0d;

        for (int i = 0; i < n; i++) {
            if (!(funs[i] instanceof DifferentiableStatisticalModel)) {
                throw new Exception("Only DifferentiableStatisticalModel allowed.");
            }
            models[i] = (DifferentiableStatisticalModel) funs[i];
            essValues[i] = models[i].getESS();
            // Written as !(x > 0) so that NaN is rejected along with zero and negative values.
            if (!(essValues[i] > 0.0d)) {
                String shown = essValues[i] == 0.0d ? "zero" : String.valueOf(essValues[i]);
                throw new IllegalArgumentException(
                        "The ess of the function " + i + " is " + shown + ", but should be positive.");
            }
            essSum += essValues[i];
            logGammaTerm -= Gamma.logOfGamma(essValues[i]);
        }
        logGammaTerm += Gamma.logOfGamma(essSum);

        // Commit only after every model has been validated.
        this.function = models;
        this.ess = essValues;
        this.classPars = new double[n];
        this.logPartNorm = new double[n];
        this.fullEss = essSum;
        this.logGammaSum = logGammaTerm;
        this.freeParameters = freeParameters;
    }

    /**
     * Adds the gradient of this log prior, evaluated at {@code params}, to {@code grad}.
     *
     * @param params the current parameter vector (layout described in the class documentation)
     * @param grad the gradient accumulator; entries are incremented, not overwritten
     * @throws EvaluationException if a model fails to provide a required value; the original
     *         exception is attached as the cause
     */
    @Override
    public void addGradientFor(double[] params, double[] grad) throws EvaluationException {
        try {
            int numClassPars = numberOfExplicitClassParameters();
            for (int i = 0; i < numClassPars; i++) {
                classPars[i] = params[i];
                logPartNorm[i] = classPars[i] + function[i].getLogNormalizationConstant();
            }
            if (freeParameters) {
                classPars[numClassPars] = 0.0d;
                logPartNorm[numClassPars] = function[numClassPars].getLogNormalizationConstant();
            }

            // Normalisation.logSumNormalisation returns the log of the sum of exp(logPartNorm[i]).
            // It also normalises logPartNorm in place to exp(logPartNorm[i] - logSum).
            // The class-parameter gradient below relies on these normalised weights.
            double logSum = Normalisation.logSumNormalisation(logPartNorm);

            // d/d b_i = ess_i - fullEss * (normalised weight of model i)
            for (int i = 0; i < numClassPars; i++) {
                grad[i] += ess[i] - (fullEss * logPartNorm[i]);
            }

            // The parameters of each model follow the class parameters, in model order.
            int offset = numClassPars;
            for (int m = 0; m < function.length; m++) {
                function[m].addGradientOfLogPriorTerm(grad, offset);
                int numModelPars = function[m].getNumberOfParameters();
                for (int j = 0; j < numModelPars; j++, offset++) {
                    double logPartial = classPars[m] + function[m].getLogPartialNormalizationConstant(j);
                    grad[offset] -= fullEss * Math.exp(logPartial - logSum);
                }
            }
        } catch (Exception e) {
            throw toEvaluationException(e.getMessage(), e);
        }
    }

    /**
     * Evaluates the log prior. The formula is given in the class documentation.
     *
     * @param params the parameter vector; only the class parameters are read here
     * @return the value of the log prior
     * @throws EvaluationException if a model fails to provide a required value; the original
     *         exception is attached as the cause
     */
    @Override
    public double evaluateFunction(double[] params) throws DimensionException, EvaluationException {
        try {
            int numClassPars = numberOfExplicitClassParameters();
            double logNorm = Double.NEGATIVE_INFINITY;
            double linearAndPriorTerms = 0.0d;
            for (int i = 0; i < numClassPars; i++) {
                logNorm = Normalisation.getLogSum(logNorm, params[i] + function[i].getLogNormalizationConstant());
                linearAndPriorTerms += (params[i] * ess[i]) + function[i].getLogPriorTerm();
            }
            if (freeParameters) {
                // The fixed class parameter is 0, so it adds no linear term.
                logNorm = Normalisation.getLogSum(logNorm, function[numClassPars].getLogNormalizationConstant());
                linearAndPriorTerms += function[numClassPars].getLogPriorTerm();
            }
            return (logGammaSum - (fullEss * logNorm)) + linearAndPriorTerms;
        } catch (Exception e) {
            // The previous code called e.getCause().getMessage() unconditionally. That threw an NPE
            // for cause-less exceptions and hid the real error.
            Throwable source = e.getCause() != null ? e.getCause() : e;
            throw toEvaluationException(source.getMessage(), e);
        }
    }

    /**
     * Returns the total number of parameters: the explicit class parameters plus every model's
     * parameters.
     *
     * @return the dimension, or -1 if any model reports -1 parameters (treated as unknown)
     */
    @Override
    public int getDimensionOfScope() {
        int dimension = numberOfExplicitClassParameters();
        for (DifferentiableStatisticalModel model : function) {
            int numModelPars = model.getNumberOfParameters();
            if (numModelPars == -1) {
                return -1;
            }
            dimension += numModelPars;
        }
        return dimension;
    }

    /**
     * Returns a fresh, unconfigured instance of this prior.
     *
     * @return a new {@link CompositeLogPrior}
     */
    @Override
    public CompositeLogPrior getNewInstance() throws CloneNotSupportedException {
        return new CompositeLogPrior();
    }

    /**
     * Returns the XML representation. This prior stores no state, so the buffer is empty.
     *
     * @return an empty buffer
     */
    @Override
    public StringBuffer toXML() {
        return new StringBuffer(1);
    }

    @Override
    public String getInstanceName() {
        return "Composite log prior";
    }

    /** Number of class parameters in the parameter vector (the fixed one is excluded). */
    private int numberOfExplicitClassParameters() {
        return function.length - (freeParameters ? 1 : 0);
    }

    /** Wraps {@code cause} in an {@link EvaluationException} with {@code message}, keeping the cause chain. */
    private static EvaluationException toEvaluationException(String message, Exception cause) {
        EvaluationException wrapped = new EvaluationException(message);
        wrapped.initCause(cause);
        return wrapped;
    }
}
