package de.jstacs.classifiers.differentiableSequenceScoreBased.gendismix;

import de.jstacs.algorithms.optimization.DimensionException;
import de.jstacs.algorithms.optimization.EvaluationException;
import de.jstacs.classifiers.differentiableSequenceScoreBased.DiffSSBasedOptimizableFunction;
import de.jstacs.classifiers.differentiableSequenceScoreBased.logPrior.LogPrior;
import de.jstacs.data.DataSet;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.sequenceScores.differentiable.DifferentiableSequenceScore;
import de.jstacs.sequenceScores.statisticalModels.differentiable.DifferentiableStatisticalModel;
import de.jstacs.utils.Normalisation;
import java.util.Arrays;

/* loaded from: input_file:de/jstacs/classifiers/differentiableSequenceScoreBased/gendismix/LogGenDisMixFunction.class */
/**
 * Objective function of the GenDisMix learning principle for {@link DifferentiableSequenceScore}s.
 * It is a weighted sum of a log conditional likelihood, a log likelihood and a log prior. The
 * function and its gradient are evaluated in parts by several threads and then joined.
 *
 * <p>The weights are held in {@link #beta}:
 * <ul>
 * <li>{@code beta[0]} weights the conditional likelihood.</li>
 * <li>{@code beta[1]} weights the likelihood.</li>
 * <li>{@code beta[2]} weights the prior.</li>
 * </ul>
 * A weight of {@code 0} switches the corresponding term off. Data set {@code k} holds the
 * sequences of class {@code k}.
 */
public class LogGenDisMixFunction extends DiffSSBasedOptimizableFunction {
    /**
     * Per-thread scratch space with at least two entries per thread.
     *
     * <p>While the sequences are processed it holds per-class log joint scores
     * ({@code logClazz[c] + log score_c(x)}), or class posteriors after normalisation. At the end
     * of {@link #evaluateFunction}, entry 0 holds the thread's partial log likelihood and entry 1
     * its partial log conditional likelihood.
     */
    protected double[][] helpArray;
    /** Per-thread partial gradients of the log likelihood; row 0 also receives the merged result. */
    protected double[][] llGrad;
    /** Per-thread partial gradients of the log conditional likelihood; row 0 also receives the merged result. */
    protected double[][] cllGrad;
    /** Weights of conditional likelihood (index 0), likelihood (index 1) and prior (index 2). */
    protected double[] beta;
    /** Buffer for the gradient of the log prior. */
    protected double[] prGrad;

    /**
     * Creates the function.
     *
     * @param threads    number of threads; one scratch row of {@link #helpArray} is allocated per thread
     * @param score      the scores, one per class (forwarded to the super constructor)
     * @param data       the data, one data set per class (forwarded to the super constructor)
     * @param weights    the sequence weights, indexed like {@code data} (forwarded to the super constructor)
     * @param prior      the log prior (forwarded to the super constructor)
     * @param beta       weights of conditional likelihood, likelihood and prior; validated by
     *                   {@code LearningPrinciple.checkWeights}
     * @param norm       forwarded to the super constructor; presumably whether the result is divided by
     *                   the total weight (see the use of {@code this.norm}). Verify against
     *                   DiffSSBasedOptimizableFunction
     * @param freeParams forwarded to the super constructor; presumably whether only free parameters are
     *                   used. Verify against DiffSSBasedOptimizableFunction
     * @throws IllegalArgumentException if there is no class, or if the conditional likelihood is used
     *                                  with fewer than two classes, or if the likelihood is used with a
     *                                  score that is not a {@link DifferentiableStatisticalModel}
     */
    public LogGenDisMixFunction(int threads, DifferentiableSequenceScore[] score, DataSet[] data, double[][] weights, LogPrior prior, double[] beta, boolean norm, boolean freeParams) throws IllegalArgumentException {
        super(threads, score, data, weights, prior, norm, freeParams);
        // A conditional likelihood (beta[0] != 0) needs at least two classes to discriminate between.
        if (this.cl < 1 || (beta[0] != 0.0d && this.cl < 2)) {
            throw new IllegalArgumentException("The number of classes is not correct. You can use this class for the generative training of one class or the (in some kind) discriminative training of more than one class.");
        }
        this.beta = LearningPrinciple.checkWeights(beta);
        check();
        // At least two entries per thread: evaluateFunction stores two partial sums at indices 0 and 1.
        this.helpArray = new double[threads][Math.max(2, this.cl)];
    }

    /**
     * Joins the per-thread partial gradients into the gradient of the whole objective.
     *
     * <p>The result is
     * {@code (beta[1] * gradLL + beta[0] * gradCLL + beta[2] * gradPrior) / (norm ? sum[cl] : 1)}.
     * When the likelihood is used, the gradient of the log normaliser is subtracted from the merged
     * likelihood gradient first.
     *
     * @return a new array of length {@code getDimensionOfScope()}
     * @throws EvaluationException if a normalisation constant cannot be computed
     */
    @Override
    protected double[] joinGradients() throws EvaluationException {
        mergeThreadGradients();
        if (this.beta[1] != 0.0d) {
            subtractLikelihoodNormalizationGradient();
        }
        Arrays.fill(this.prGrad, 0.0d);
        if (this.beta[2] != 0.0d) {
            this.prior.addGradientFor(this.params, this.prGrad);
        }
        // Switched-off terms contribute exactly zero (also guards against NaN * 0).
        if (this.beta[1] == 0.0d) {
            Arrays.fill(this.llGrad[0], 0.0d);
        }
        if (this.beta[0] == 0.0d) {
            Arrays.fill(this.cllGrad[0], 0.0d);
        }
        double[] gradient = new double[getDimensionOfScope()];
        double denominator = this.norm ? this.sum[this.cl] : 1.0d;
        for (int p = 0; p < gradient.length; p++) {
            gradient[p] = ((this.beta[1] * this.llGrad[0][p] + this.beta[0] * this.cllGrad[0][p]) + this.beta[2] * this.prGrad[p]) / denominator;
        }
        return gradient;
    }

    /** Adds the partial gradients of threads 1..n-1 into row 0 of {@link #llGrad} and {@link #cllGrad}. */
    private void mergeThreadGradients() {
        double[] ll = this.llGrad[0];
        double[] cll = this.cllGrad[0];
        for (int t = 1; t < this.llGrad.length; t++) {
            double[] llPart = this.llGrad[t];
            double[] cllPart = this.cllGrad[t];
            for (int p = 0; p < ll.length; p++) {
                ll[p] += llPart[p];
                cll[p] += cllPart[p];
            }
        }
    }

    /**
     * Subtracts {@code sum[cl]} times the gradient of
     * {@code log sum_c exp(logClazz[c] + logZ_c)} from row 0 of {@link #llGrad}, where
     * {@code logZ_c} is the log normalisation constant of the model of class {@code c}.
     *
     * <p>Indices below {@code shortcut[0]} address class parameters. Indices
     * {@code shortcut[c] .. shortcut[c+1]-1} address the parameters of the model of class {@code c}.
     *
     * @throws EvaluationException if a (partial) normalisation constant cannot be computed
     */
    private void subtractLikelihoodNormalizationGradient() throws EvaluationException {
        double[] logTerms = getLogClassNormalizationTerms();
        double logNorm = logSum(logTerms);
        double total = this.sum[this.cl];
        double[] ll = this.llGrad[0];
        try {
            for (int c = 0; c < this.cl; c++) {
                if (c < this.shortcut[0]) {
                    ll[c] -= total * Math.exp(logTerms[c] - logNorm);
                }
                DifferentiableStatisticalModel model = (DifferentiableStatisticalModel) this.score[0][c];
                for (int p = this.shortcut[c]; p < this.shortcut[c + 1]; p++) {
                    ll[p] -= total * Math.exp(this.logClazz[c] + model.getLogPartialNormalizationConstant(p - this.shortcut[c]) - logNorm);
                }
            }
        } catch (Exception e) {
            throw toEvaluationException(e);
        }
    }

    /**
     * Returns, for every class {@code c}, {@code logClazz[c]} plus the log normalisation constant of
     * the model of class {@code c}. The scores of thread 0 are used.
     *
     * @throws EvaluationException if a normalisation constant cannot be computed
     */
    private double[] getLogClassNormalizationTerms() throws EvaluationException {
        double[] logTerms = new double[this.cl];
        try {
            for (int c = 0; c < this.cl; c++) {
                logTerms[c] = this.logClazz[c] + ((DifferentiableStatisticalModel) this.score[0][c]).getLogNormalizationConstant();
            }
        } catch (Exception e) {
            throw toEvaluationException(e);
        }
        return logTerms;
    }

    /** Returns {@code log(sum_i exp(logTerms[i]))}, accumulated pairwise from negative infinity. */
    private static double logSum(double[] logTerms) {
        double logNorm = Double.NEGATIVE_INFINITY;
        for (double logTerm : logTerms) {
            logNorm = Normalisation.getLogSum(logNorm, logTerm);
        }
        return logNorm;
    }

    /** Wraps {@code e} in an {@link EvaluationException} with the same message and stack trace. */
    private static EvaluationException toEvaluationException(Exception e) {
        EvaluationException evaluationException = new EvaluationException(e.getMessage());
        evaluationException.setStackTrace(e.getStackTrace());
        return evaluationException;
    }

    /**
     * Computes thread {@code index}'s partial gradients for the sequences from
     * {@code data[startClass]} position {@code startSeq} up to {@code data[endClass]} position
     * {@code endSeq} (exclusive).
     *
     * <p>The results go into {@code llGrad[index]} and {@code cllGrad[index]}.
     */
    @Override
    protected void evaluateGradientOfFunction(int index, int startClass, int startSeq, int endClass, int endSeq) {
        double[] ll = this.llGrad[index];
        double[] cll = this.cllGrad[index];
        double[] help = this.helpArray[index];
        Arrays.fill(ll, 0.0d);
        Arrays.fill(cll, 0.0d);
        for (int k = startClass; k <= endClass; k++) {
            int first = k == startClass ? startSeq : 0;
            int last = k == endClass ? endSeq : this.data[k].getNumberOfElements();
            for (int n = first; n < last; n++) {
                Sequence seq = this.data[k].getElementAt(n);
                double w = this.weights[k][n];
                // The conditional likelihood needs the scores of all classes; otherwise only the true class k.
                if (this.beta[0] != 0.0d) {
                    for (int c = 0; c < this.cl; c++) {
                        computeLogJointAndDerivatives(index, c, seq);
                    }
                } else {
                    computeLogJointAndDerivatives(index, k, seq);
                }
                if (this.beta[1] != 0.0d) {
                    // d/dθ of w * (logClazz[k] + log score_k(x)).
                    if (k < this.shortcut[0]) {
                        ll[k] += w;
                    }
                    int offset = this.shortcut[k];
                    for (int j = 0; j < this.iList[index][k].length(); j++) {
                        ll[offset + this.iList[index][k].get(j)] += w * this.dList[index][k].get(j);
                    }
                }
                if (this.beta[0] != 0.0d) {
                    // Turn the log joint scores into class posteriors P(c|x), in place.
                    Normalisation.logSumNormalisation(help, 0, help.length, help, 0);
                    // d/dθ log P(k|x) = sum_c (δ_ck - P(c|x)) * d/dθ (logClazz[c] + log score_c(x)).
                    for (int c = 0; c < this.shortcut[0]; c++) {
                        cll[c] += w * ((c == k ? 1.0d : 0.0d) - help[c]);
                    }
                    for (int c = 0; c < this.cl; c++) {
                        double delta = (c == k ? 1.0d : 0.0d) - help[c];
                        int offset = this.shortcut[c];
                        for (int j = 0; j < this.iList[index][c].length(); j++) {
                            cll[offset + this.iList[index][c].get(j)] += w * this.dList[index][c].get(j) * delta;
                        }
                    }
                }
            }
        }
    }

    /**
     * Stores {@code logClazz[c] + log score_c(seq)} in {@code helpArray[index][c]}.
     *
     * <p>It also collects the partial derivatives reported by {@code getLogScoreAndPartialDerivation}:
     * parameter indices go to {@code iList[index][c]} and values to {@code dList[index][c]}.
     */
    private void computeLogJointAndDerivatives(int index, int c, Sequence seq) {
        this.iList[index][c].clear();
        this.dList[index][c].clear();
        this.helpArray[index][c] = this.logClazz[c] + this.score[index][c].getLogScoreAndPartialDerivation(seq, 0, this.iList[index][c], this.dList[index][c]);
    }

    /**
     * Joins the per-thread partial function values into the value of the whole objective.
     *
     * <p>The result is {@code (beta[1] * LL + beta[0] * CLL + beta[2] * logPrior)}, divided by
     * {@code sum[cl]} if {@code norm} is set. When the likelihood is used, {@code LL} is first
     * reduced by {@code sum[cl] * log sum_c exp(logClazz[c] + logZ_c)}.
     *
     * @throws DimensionException  propagated from the prior
     * @throws EvaluationException if a normalisation constant cannot be computed or the value is not finite
     */
    @Override
    protected double joinFunction() throws DimensionException, EvaluationException {
        double cll = 0.0d;
        double ll = 0.0d;
        for (double[] partial : this.helpArray) {
            ll += partial[0];
            cll += partial[1];
        }
        if (this.beta[1] != 0.0d) {
            ll -= this.sum[this.cl] * logSum(getLogClassNormalizationTerms());
        }
        double logPrior = this.beta[2] != 0.0d ? this.prior.evaluateFunction(this.params) : 0.0d;
        // Switched-off terms contribute exactly zero (also guards against NaN * 0).
        if (this.beta[1] == 0.0d) {
            ll = 0.0d;
        }
        if (this.beta[0] == 0.0d) {
            cll = 0.0d;
        }
        double value = (this.beta[1] * ll) + (this.beta[0] * cll) + (this.beta[2] * logPrior);
        if (!Double.isNaN(value) && !Double.isInfinite(value)) {
            return this.norm ? value / this.sum[this.cl] : value;
        }
        // Diagnostic dump before failing. NOTE(review): consider moving this into the exception or a logger.
        System.out.println(String.valueOf(value) + "\t= " + this.beta[0] + " * " + cll + " + " + this.beta[1] + " * " + ll + " + " + this.beta[2] + " * " + logPrior);
        System.out.println("params " + Arrays.toString(this.params));
        for (double[] partial : this.helpArray) {
            System.out.println(Arrays.toString(partial));
        }
        System.out.flush();
        throw new EvaluationException("Evaluating the function gives: " + this.beta[0] + " * " + cll + " + " + this.beta[1] + " * " + ll + " + " + this.beta[2] + " * " + logPrior);
    }

    /**
     * Computes thread {@code index}'s partial function values for the sequences from
     * {@code data[startClass]} position {@code startSeq} up to {@code data[endClass]} position
     * {@code endSeq} (exclusive).
     *
     * <p>On return, {@code helpArray[index][0]} holds the weighted sum of
     * {@code logClazz[k] + log score_k(x)}. {@code helpArray[index][1]} holds, when
     * {@code beta[0] != 0}, the weighted sum of
     * {@code log(score_k-joint) - log sum_c (score_c-joint)}, and 0 otherwise.
     */
    @Override
    protected void evaluateFunction(int index, int startClass, int startSeq, int endClass, int endSeq) throws EvaluationException {
        double[] help = this.helpArray[index];
        double cll = 0.0d;
        double ll = 0.0d;
        for (int k = startClass; k <= endClass; k++) {
            int first = k == startClass ? startSeq : 0;
            int last = k == endClass ? endSeq : this.data[k].getNumberOfElements();
            for (int n = first; n < last; n++) {
                Sequence seq = this.data[k].getElementAt(n);
                if (this.beta[0] != 0.0d) {
                    for (int c = 0; c < this.cl; c++) {
                        help[c] = this.logClazz[c] + this.score[index][c].getLogScoreFor(seq, 0);
                    }
                    cll += this.weights[k][n] * (help[k] - Normalisation.getLogSum(help));
                } else {
                    help[k] = this.logClazz[k] + this.score[index][k].getLogScoreFor(seq, 0);
                }
                ll += this.weights[k][n] * help[k];
            }
        }
        help[0] = ll;
        help[1] = cll;
    }

    /**
     * Ensures that, when the likelihood is used ({@code beta[1] != 0}), every score of thread 0
     * is a {@link DifferentiableStatisticalModel}. The likelihood needs their normalisation
     * constants.
     *
     * @throws IllegalArgumentException if some score is not a {@link DifferentiableStatisticalModel}
     */
    private void check() throws IllegalArgumentException {
        if (this.beta[1] == 0.0d) {
            return;
        }
        for (int c = 0; c < this.score[0].length; c++) {
            if (!(this.score[0][c] instanceof DifferentiableStatisticalModel)) {
                throw new IllegalArgumentException("For evaluating the likelihood all scores have to be instances of DifferentiableStatisticalModel, but score " + c + " is not.");
            }
        }
    }

    /**
     * Installs new scores and reallocates the gradient buffers.
     *
     * <ul>
     * <li>Puts one score per class into thread 0 and gives every further thread its own clone.</li>
     * <li>Recomputes the parameter offsets in {@code shortcut}.</li>
     * <li>Re-validates the scores.</li>
     * <li>Re-initialises the prior when it is used.</li>
     * <li>Reallocates the gradient buffers.</li>
     * </ul>
     *
     * <p>NOTE(review): the prior is only set here when {@code beta[2] > 0} and it is non-null.
     * {@code joinFunction}/{@code joinGradients} call it whenever {@code beta[2] != 0}, without a
     * null check.
     *
     * @param differentiableSequenceScoreArr the new scores, one per class
     * @throws Exception if cloning a score or setting the prior fails
     */
    @Override
    public void reset(DifferentiableSequenceScore[] differentiableSequenceScoreArr) throws Exception {
        for (int c = 0; c < this.cl; c++) {
            this.score[0][c] = differentiableSequenceScoreArr[c];
            for (int t = 1; t < this.score.length; t++) {
                this.score[t][c] = this.score[0][c].mo114clone();
            }
            this.shortcut[c + 1] = this.shortcut[c] + this.score[0][c].getNumberOfParameters();
        }
        check();
        if (this.beta[2] > 0.0d && this.prior != null) {
            this.prior.set(this.freeParams, this.score[0]);
        }
        int dimension = getDimensionOfScope();
        this.llGrad = new double[getNumberOfThreads()][dimension];
        this.cllGrad = new double[this.llGrad.length][dimension];
        this.prGrad = new double[dimension];
    }
}
