package de.jstacs.classifiers.differentiableSequenceScoreBased.gendismix;

import de.jstacs.algorithms.optimization.EvaluationException;
import de.jstacs.classifiers.differentiableSequenceScoreBased.logPrior.LogPrior;
import de.jstacs.data.DataSet;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.sequenceScores.differentiable.DifferentiableSequenceScore;
import de.jstacs.utils.Normalisation;
import java.util.Arrays;

/* loaded from: input_file:de/jstacs/classifiers/differentiableSequenceScoreBased/gendismix/OneDataSetLogGenDisMixFunction.class */
/**
 * A {@link LogGenDisMixFunction} that works on a single data set which is shared by all classes.
 *
 * <p>Instead of one data set per class, this function holds exactly one data set and one weight
 * array per class; every weight array must have one entry per sequence of that data set
 * (see {@link #setDataAndWeights(DataSet[], double[][])}).
 */
public class OneDataSetLogGenDisMixFunction extends LogGenDisMixFunction {
    /**
     * Creates the function for a single shared data set.
     *
     * <p>The data set is wrapped into a one-element array; all other arguments are forwarded
     * unchanged to the {@link LogGenDisMixFunction} constructor, which defines their meaning.
     *
     * @param i forwarded to the super constructor (presumably the number of threads; confirm against the super class)
     * @param differentiableSequenceScoreArr the differentiable sequence scores, forwarded to the super constructor
     * @param dataSet the single data set used for all classes
     * @param dArr the weights, one array per class, each with one entry per sequence of {@code dataSet}
     * @param logPrior the prior, forwarded to the super constructor
     * @param dArr2 forwarded to the super constructor (presumably the beta weighting of the objective terms; confirm)
     * @param z forwarded to the super constructor
     * @param z2 forwarded to the super constructor
     * @throws IllegalArgumentException if the super constructor rejects the arguments
     */
    public OneDataSetLogGenDisMixFunction(int i, DifferentiableSequenceScore[] differentiableSequenceScoreArr, DataSet dataSet, double[][] dArr, LogPrior logPrior, double[] dArr2, boolean z, boolean z2) throws IllegalArgumentException {
        super(i, differentiableSequenceScoreArr, new DataSet[]{dataSet}, dArr, logPrior, dArr2, z, z2);
    }

    /**
     * Sets the single data set and the per-class weights and recomputes the weight sums.
     *
     * <p>After a successful call, {@code sum[c]} holds the sum of the weights of class {@code c}
     * and {@code sum[cl]} holds the sum over all classes. All arguments are validated before any
     * field is modified, so a rejected call leaves this instance unchanged.
     *
     * @param dataSetArr an array containing exactly one data set
     * @param dArr the weights; exactly {@code cl} arrays, each with as many entries as the data set has elements
     * @throws IllegalArgumentException if an argument is {@code null} or has a wrong dimension
     */
    @Override
    public void setDataAndWeights(DataSet[] dataSetArr, double[][] dArr) throws IllegalArgumentException {
        if (dataSetArr == null || dataSetArr.length != 1 || dataSetArr[0] == null || dArr == null || dArr.length != this.cl) {
            throw new IllegalArgumentException("The dimension of the data set or weights (array) is not correct.");
        }

        // Validate every weight array up front: nothing is committed unless the whole input is valid.
        final int numberOfElements = dataSetArr[0].getNumberOfElements();
        for (int c = 0; c < this.cl; c++) {
            if (dArr[c] == null || dArr[c].length != numberOfElements) {
                throw new IllegalArgumentException("The dimension of the " + c + "-th weights (array) is not correct.");
            }
        }

        this.data = dataSetArr;
        this.weights = dArr;

        // Accumulate in the same order as before (per class, then across classes) to keep the
        // floating point results identical.
        double total = 0.0d;
        for (int c = 0; c < this.cl; c++) {
            double classSum = 0.0d;
            for (int n = 0; n < dArr[c].length; n++) {
                classSum += dArr[c][n];
            }
            this.sum[c] = classSum;
            total += classSum;
        }
        this.sum[this.cl] = total;

        if (this.worker != null) {
            prepareThreads();
        }
    }

    /**
     * Returns the data as one entry per weight array.
     *
     * @return a new array of length {@code weights.length} in which every entry references the single shared data set
     */
    @Override
    public DataSet[] getData() {
        DataSet[] perClass = new DataSet[this.weights.length];
        Arrays.fill(perClass, this.data[0]);
        return perClass;
    }

    /**
     * Computes the gradient contributions of the sequences {@code i3} (inclusive) to {@code i5}
     * (exclusive) of the shared data set.
     *
     * <p>{@code llGrad[i]} and {@code cllGrad[i]} are reset to zero and then accumulated.
     * The {@code llGrad} part is only computed if {@code beta[1] != 0}, the {@code cllGrad} part
     * only if {@code beta[0] != 0}.
     *
     * @param i index selecting the per-worker buffers ({@code llGrad}, {@code cllGrad}, {@code helpArray},
     *          {@code iList}, {@code dList}, {@code score}); presumably the thread index
     * @param i2 not used by this implementation
     * @param i3 index of the first sequence to process (inclusive)
     * @param i4 not used by this implementation
     * @param i5 index after the last sequence to process (exclusive)
     */
    @Override
    protected void evaluateGradientOfFunction(int i, int i2, int i3, int i4, int i5) {
        final double[] llGradient = this.llGrad[i];
        final double[] cllGradient = this.cllGrad[i];
        final double[] classValues = this.helpArray[i];

        Arrays.fill(llGradient, 0.0d);
        Arrays.fill(cllGradient, 0.0d);

        for (int seqIdx = i3; seqIdx < i5; seqIdx++) {
            final Sequence seq = this.data[0].getElementAt(seqIdx);

            // Class term plus log score per class; the partial derivatives are collected
            // in iList (parameter indices) and dList (derivative values).
            for (int c = 0; c < this.cl; c++) {
                this.iList[i][c].clear();
                this.dList[i][c].clear();
                classValues[c] = this.logClazz[c] + this.score[i][c].getLogScoreAndPartialDerivation(seq, 0, this.iList[i][c], this.dList[i][c]);
            }
            // Normalises in place; afterwards classValues presumably holds the class
            // probabilities for this sequence (see Normalisation.logSumNormalisation).
            Normalisation.logSumNormalisation(classValues, 0, classValues.length, classValues, 0);

            for (int c = 0; c < this.cl; c++) {
                final double weight = this.weights[c][seqIdx];

                if (this.beta[1] != 0.0d) {
                    // The first shortcut[0] gradient entries belong to the class parameters.
                    if (c < this.shortcut[0]) {
                        llGradient[c] += weight;
                    }
                    for (int p = 0; p < this.iList[i][c].length(); p++) {
                        final int idx = this.shortcut[c] + this.iList[i][c].get(p);
                        llGradient[idx] += weight * this.dList[i][c].get(p);
                    }
                }

                if (this.beta[0] != 0.0d) {
                    // Class parameters: weight * (indicator(k == c) - classValues[k]).
                    for (int k = 0; k < this.shortcut[0]; k++) {
                        if (k != c) {
                            cllGradient[k] -= weight * classValues[k];
                        } else {
                            cllGradient[k] += weight * (1.0d - classValues[k]);
                        }
                    }
                    // Score parameters of every class, using the same indicator-minus-value factor.
                    for (int k = 0; k < this.cl; k++) {
                        final int count = this.iList[i][k].length();
                        if (k != c) {
                            for (int p = 0; p < count; p++) {
                                final int idx = this.shortcut[k] + this.iList[i][k].get(p);
                                cllGradient[idx] -= weight * this.dList[i][k].get(p) * classValues[k];
                            }
                        } else {
                            for (int p = 0; p < count; p++) {
                                final int idx = this.shortcut[k] + this.iList[i][k].get(p);
                                cllGradient[idx] += weight * this.dList[i][k].get(p) * (1.0d - classValues[k]);
                            }
                        }
                    }
                }
            }
        }
    }

    /**
     * Computes the function value contributions of the sequences {@code i3} (inclusive) to
     * {@code i5} (exclusive) of the shared data set.
     *
     * <p>On return, {@code helpArray[i][0]} holds the weighted sum of the per-class values
     * ({@code logClazz[c]} plus log score) and {@code helpArray[i][1]} holds the weighted sum of
     * the same values reduced by their log-sum over all classes. The log-sum is only computed if
     * {@code beta[0] != 0}; otherwise it is taken as {@code 0}.
     *
     * @param i index selecting the per-worker buffers ({@code helpArray}, {@code score}); presumably the thread index
     * @param i2 not used by this implementation
     * @param i3 index of the first sequence to process (inclusive)
     * @param i4 not used by this implementation
     * @param i5 index after the last sequence to process (exclusive)
     * @throws EvaluationException declared by the overridden method
     */
    @Override
    protected void evaluateFunction(int i, int i2, int i3, int i4, int i5) throws EvaluationException {
        final double[] classValues = this.helpArray[i];
        double conditionalPart = 0.0d;
        double likelihoodPart = 0.0d;
        double logSum = 0.0d;

        for (int seqIdx = i3; seqIdx < i5; seqIdx++) {
            final Sequence seq = this.data[0].getElementAt(seqIdx);
            for (int c = 0; c < this.cl; c++) {
                classValues[c] = this.logClazz[c] + this.score[i][c].getLogScoreFor(seq, 0);
            }
            if (this.beta[0] != 0.0d) {
                logSum = Normalisation.getLogSum(classValues);
            }
            for (int c = 0; c < this.cl; c++) {
                final double weight = this.weights[c][seqIdx];
                conditionalPart += weight * (classValues[c] - logSum);
                likelihoodPart += weight * classValues[c];
            }
        }

        // The buffer is reused to hand both partial results back to the caller.
        classValues[0] = likelihoodPart;
        classValues[1] = conditionalPart;
    }
}
