/*
 * Decompiled with CFR 0.152.
 */
package de.jstacs.classifiers.differentiableSequenceScoreBased.gendismix;

import de.jstacs.algorithms.optimization.DimensionException;
import de.jstacs.algorithms.optimization.EvaluationException;
import de.jstacs.classifiers.differentiableSequenceScoreBased.DiffSSBasedOptimizableFunction;
import de.jstacs.classifiers.differentiableSequenceScoreBased.gendismix.LearningPrinciple;
import de.jstacs.classifiers.differentiableSequenceScoreBased.logPrior.LogPrior;
import de.jstacs.data.DataSet;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.sequenceScores.differentiable.DifferentiableSequenceScore;
import de.jstacs.sequenceScores.statisticalModels.differentiable.DifferentiableStatisticalModel;
import de.jstacs.utils.Normalisation;
import java.util.Arrays;

/**
 * Objective function of the unified generative-discriminative ("GenDisMix") learning principle.
 *
 * <p>The value is the weighted sum {@code beta[1] * LL + beta[0] * CLL + beta[2] * LPR} of the
 * log-likelihood {@code LL}, the conditional log-likelihood {@code CLL} of the class labels and
 * the log-prior {@code LPR}. A term whose weight is {@code 0} contributes exactly {@code 0} to both
 * value and gradient. If {@code norm} is set, value and gradient are divided by {@code sum[cl]}
 * (NOTE(review): presumably the total sequence weight provided by the superclass - confirm).
 */
public class LogGenDisMixFunction extends DiffSSBasedOptimizableFunction {

    /** Index of the conditional log-likelihood weight in {@link #beta}. */
    private static final int CLL_WEIGHT = 0;
    /** Index of the log-likelihood weight in {@link #beta}. */
    private static final int LL_WEIGHT = 1;
    /** Index of the log-prior weight in {@link #beta}. */
    private static final int PRIOR_WEIGHT = 2;

    /**
     * Per-index scratch rows of length {@code max(2, cl)}. They hold class-wise log scores (turned
     * into class probabilities during gradient evaluation). After
     * {@link #evaluateFunction(int, int, int, int, int)} they also carry the partial LL in column 0
     * and the partial CLL in column 1, which {@link #joinFunction()} sums up.
     */
    protected double[][] helpArray;
    /** Per-index partial gradients of the log-likelihood; reduced into row 0 by joinGradients. */
    protected double[][] llGrad;
    /** Per-index partial gradients of the conditional log-likelihood; reduced into row 0. */
    protected double[][] cllGrad;
    /** Weights of conditional log-likelihood, log-likelihood and log-prior (in this order). */
    protected double[] beta;
    /** Buffer for the gradient of the log-prior. */
    protected double[] prGrad;

    /**
     * Creates a new GenDisMix objective.
     *
     * @param threads the number of threads; one scratch row of {@link #helpArray} is allocated per thread
     * @param score the score functions, one per class
     * @param data the data sets, one per class
     * @param weights the sequence weights, indexed like {@code data}
     * @param prior the prior used for the log-prior term
     * @param beta the weights of conditional log-likelihood, log-likelihood and log-prior (in this
     *     order); validated by {@link LearningPrinciple#checkWeights(double[])}
     * @param norm whether value and gradient are divided by {@code sum[cl]}
     * @param freeParams forwarded to the superclass and later to the prior
     * @throws IllegalArgumentException if there is no class, if the conditional likelihood is
     *     weighted but there are fewer than two classes, or if the likelihood is weighted but a
     *     score is not a {@link DifferentiableStatisticalModel}
     */
    public LogGenDisMixFunction(int threads, DifferentiableSequenceScore[] score, DataSet[] data, double[][] weights, LogPrior prior, double[] beta, boolean norm, boolean freeParams) throws IllegalArgumentException {
        super(threads, score, data, weights, prior, norm, freeParams);
        if (cl < 1 || (beta[CLL_WEIGHT] != 0.0 && cl < 2)) {
            throw new IllegalArgumentException("The number of classes is not correct. You can use this class for the generative training of one class or the (in some kind) discriminative training of more than one class.");
        }
        this.beta = LearningPrinciple.checkWeights(beta);
        check();
        helpArray = new double[threads][Math.max(2, cl)];
    }

    /**
     * Reduces the per-index partial gradients and combines the weighted terms into the final gradient.
     *
     * @return the gradient, of length {@code shortcut[cl]}
     * @throws EvaluationException if the normalization-constant gradient cannot be evaluated, or
     *     propagated from the prior
     */
    @Override
    protected double[] joinGradients() throws EvaluationException {
        // Sum rows 1..T-1 into row 0; every element still accumulates threads in ascending order.
        for (int t = 1; t < llGrad.length; t++) {
            for (int i = 0; i < llGrad[0].length; i++) {
                llGrad[0][i] += llGrad[t][i];
                cllGrad[0][i] += cllGrad[t][i];
            }
        }
        if (beta[LL_WEIGHT] != 0.0) {
            subtractLogNormalizationGradient();
        }
        Arrays.fill(prGrad, 0.0);
        if (beta[PRIOR_WEIGHT] != 0.0) {
            prior.addGradientFor(params, prGrad);
        }
        // A term with weight 0 must contribute exactly 0, even if its partial gradient is
        // non-finite (0 * Inf would be NaN).
        if (beta[LL_WEIGHT] == 0.0) {
            Arrays.fill(llGrad[0], 0.0);
        }
        if (beta[CLL_WEIGHT] == 0.0) {
            Arrays.fill(cllGrad[0], 0.0);
        }
        double[] grad = new double[shortcut[cl]];
        double divisor = norm ? sum[cl] : 1.0;
        for (int i = 0; i < grad.length; i++) {
            grad[i] = (beta[LL_WEIGHT] * llGrad[0][i] + beta[CLL_WEIGHT] * cllGrad[0][i] + beta[PRIOR_WEIGHT] * prGrad[i]) / divisor;
        }
        return grad;
    }

    /**
     * Subtracts from {@code llGrad[0]} the gradient of {@code sum[cl] * logNorm}.
     * Here {@code logNorm} is {@link #getLogJointNormalizationConstant()}.
     * Any failure is wrapped in an {@link EvaluationException} that keeps the original cause.
     */
    private void subtractLogNormalizationGradient() throws EvaluationException {
        try {
            double logNorm = getLogJointNormalizationConstant();
            double total = sum[cl];
            for (int c = 0; c < cl; c++) {
                DifferentiableStatisticalModel model = (DifferentiableStatisticalModel) score[0][c];
                // NOTE(review): the first shortcut[0] entries appear to be per-class parameters - confirm.
                if (c < shortcut[0]) {
                    llGrad[0][c] -= total * Math.exp(logClazz[c] + model.getLogNormalizationConstant() - logNorm);
                }
                for (int p = shortcut[c]; p < shortcut[c + 1]; p++) {
                    llGrad[0][p] -= total * Math.exp(logClazz[c] + model.getLogPartialNormalizationConstant(p - shortcut[c]) - logNorm);
                }
            }
        } catch (Exception e) {
            EvaluationException eva = new EvaluationException(e.getMessage());
            eva.setStackTrace(e.getStackTrace());
            eva.initCause(e);
            throw eva;
        }
    }

    /**
     * Returns the log of the sum over all classes {@code c} of
     * {@code exp(logClazz[c] + log normalization constant of score[0][c])}.
     * The sum is accumulated in ascending class order. Every score must be a
     * {@link DifferentiableStatisticalModel}; {@link #check()} enforces this whenever the
     * likelihood weight is non-zero.
     */
    private double getLogJointNormalizationConstant() throws DimensionException, EvaluationException {
        double logNorm = Double.NEGATIVE_INFINITY;
        for (int c = 0; c < cl; c++) {
            logNorm = Normalisation.getLogSum(logNorm, logClazz[c] + ((DifferentiableStatisticalModel) score[0][c]).getLogNormalizationConstant());
        }
        return logNorm;
    }

    /**
     * Accumulates the partial LL and CLL gradients into {@code llGrad[index]} and {@code cllGrad[index]}.
     * The sequences run from {@code (startClass, startSeq)} inclusive to {@code (endClass, endSeq)}
     * exclusive. Both rows are cleared first.
     */
    @Override
    protected void evaluateGradientOfFunction(int index, int startClass, int startSeq, int endClass, int endSeq) {
        double[] llRow = llGrad[index];
        double[] cllRow = cllGrad[index];
        double[] probs = helpArray[index];
        Arrays.fill(llRow, 0.0);
        Arrays.fill(cllRow, 0.0);
        boolean discriminative = beta[CLL_WEIGHT] != 0.0;
        boolean generative = beta[LL_WEIGHT] != 0.0;
        for (int c = startClass; c <= endClass; c++) {
            int first = c == startClass ? startSeq : 0;
            int last = c == endClass ? endSeq : data[c].getNumberOfElements();
            for (int n = first; n < last; n++) {
                Sequence s = data[c].getElementAt(n);
                double w = weights[c][n];
                // The CLL needs the scores of all classes; otherwise only the true class is scored.
                if (discriminative) {
                    for (int k = 0; k < cl; k++) {
                        scoreWithDerivatives(index, k, s);
                    }
                } else {
                    scoreWithDerivatives(index, c, s);
                }
                if (generative) {
                    if (c < shortcut[0]) {
                        llRow[c] += w;
                    }
                    int offset = shortcut[c];
                    int count = iList[index][c].length();
                    for (int j = 0; j < count; j++) {
                        llRow[offset + iList[index][c].get(j)] += w * dList[index][c].get(j);
                    }
                }
                if (discriminative) {
                    // Normalize the log scores in place; afterwards probs[k] is used as p(k | s).
                    Normalisation.logSumNormalisation(probs, 0, probs.length, probs, 0);
                    for (int k = 0; k < shortcut[0]; k++) {
                        if (k == c) {
                            cllRow[k] += w * (1.0 - probs[k]);
                        } else {
                            cllRow[k] -= w * probs[k];
                        }
                    }
                    for (int k = 0; k < cl; k++) {
                        int offset = shortcut[k];
                        int count = iList[index][k].length();
                        if (k == c) {
                            for (int j = 0; j < count; j++) {
                                cllRow[offset + iList[index][k].get(j)] += w * dList[index][k].get(j) * (1.0 - probs[k]);
                            }
                        } else {
                            for (int j = 0; j < count; j++) {
                                cllRow[offset + iList[index][k].get(j)] -= w * dList[index][k].get(j) * probs[k];
                            }
                        }
                    }
                }
            }
        }
    }

    /**
     * Stores {@code logClazz[k]} plus the log score of {@code s} under class {@code k} in
     * {@code helpArray[index][k]}. It also refills {@code iList[index][k]} / {@code dList[index][k]}
     * with the partial derivatives.
     */
    private void scoreWithDerivatives(int index, int k, Sequence s) {
        iList[index][k].clear();
        dList[index][k].clear();
        helpArray[index][k] = logClazz[k] + score[index][k].getLogScoreAndPartialDerivation(s, 0, iList[index][k], dList[index][k]);
    }

    /**
     * Combines the per-index partial results of {@link #evaluateFunction(int, int, int, int, int)}
     * into the objective value.
     *
     * @return the weighted objective, divided by {@code sum[cl]} if {@code norm} is set
     * @throws EvaluationException if the result is NaN or infinite
     */
    @Override
    protected double joinFunction() throws DimensionException, EvaluationException {
        double ll = 0.0;
        double cll = 0.0;
        for (double[] partial : helpArray) {
            ll += partial[0];
            cll += partial[1];
        }
        if (beta[LL_WEIGHT] != 0.0) {
            ll -= sum[cl] * getLogJointNormalizationConstant();
        }
        double lpr = beta[PRIOR_WEIGHT] != 0.0 ? prior.evaluateFunction(params) : 0.0;
        // Zero unused terms so a non-finite partial value cannot poison the result via 0 * Inf.
        if (beta[LL_WEIGHT] == 0.0) {
            ll = 0.0;
        }
        if (beta[CLL_WEIGHT] == 0.0) {
            cll = 0.0;
        }
        double res = beta[LL_WEIGHT] * ll + beta[CLL_WEIGHT] * cll + beta[PRIOR_WEIGHT] * lpr;
        if (Double.isNaN(res) || Double.isInfinite(res)) {
            System.err.println(String.valueOf(res) + "\t= " + beta[CLL_WEIGHT] + " * " + cll + " + " + beta[LL_WEIGHT] + " * " + ll + " + " + beta[PRIOR_WEIGHT] + " * " + lpr);
            System.err.println("params " + Arrays.toString(params));
            System.err.flush();
            throw new EvaluationException("Evaluating the function gives: " + beta[CLL_WEIGHT] + " * " + cll + " + " + beta[LL_WEIGHT] + " * " + ll + " + " + beta[PRIOR_WEIGHT] + " * " + lpr);
        }
        return norm ? res / sum[cl] : res;
    }

    /**
     * Computes the partial LL and CLL over a range of sequences.
     * The range runs from {@code (startClass, startSeq)} inclusive to {@code (endClass, endSeq)}
     * exclusive. The results are stored in {@code helpArray[index][0]} and {@code helpArray[index][1]}.
     *
     * @throws EvaluationException if the accumulated conditional log-likelihood becomes infinite
     */
    @Override
    protected void evaluateFunction(int index, int startClass, int startSeq, int endClass, int endSeq) throws EvaluationException {
        double[] scores = helpArray[index];
        double cll = 0.0;
        double ll = 0.0;
        for (int c = startClass; c <= endClass; c++) {
            int first = c == startClass ? startSeq : 0;
            int last = c == endClass ? endSeq : data[c].getNumberOfElements();
            for (int n = first; n < last; n++) {
                Sequence s = data[c].getElementAt(n);
                double w = weights[c][n];
                if (beta[CLL_WEIGHT] != 0.0) {
                    for (int k = 0; k < cl; k++) {
                        scores[k] = logClazz[k] + score[index][k].getLogScoreFor(s, 0);
                    }
                    cll += w * (scores[c] - Normalisation.getLogSum(scores));
                    if (Double.isInfinite(cll)) {
                        // Diagnostics go to stderr, like the other error reports of this class.
                        System.err.println("class: " + c + ", sequence: " + n + ", index: " + index);
                        System.err.println("logClazz: " + Arrays.toString(logClazz));
                        System.err.println("helpArray: " + Arrays.toString(scores));
                        System.err.println("w: " + w);
                        System.err.println("seq: " + s);
                        System.err.println("cll: " + cll + ", ll: " + ll);
                        System.err.println("params " + Arrays.toString(params));
                        System.err.println(Arrays.toString(score[index]));
                        System.err.flush();
                        throw new EvaluationException("Infinite conditional log-likelihood for sequence " + n + " of class " + c + " (index " + index + ")");
                    }
                } else {
                    scores[c] = logClazz[c] + score[index][c].getLogScoreFor(s, 0);
                }
                ll += w * scores[c];
            }
        }
        scores[0] = ll;
        scores[1] = cll;
    }

    /**
     * Ensures every score is a {@link DifferentiableStatisticalModel} if the likelihood is
     * weighted. The normalization constants required for the LL term are only available for that type.
     *
     * @throws IllegalArgumentException naming the first offending class
     */
    private void check() throws IllegalArgumentException {
        if (beta[LL_WEIGHT] == 0.0) {
            return;
        }
        for (int i = 0; i < score[0].length; i++) {
            if (!(score[0][i] instanceof DifferentiableStatisticalModel)) {
                throw new IllegalArgumentException("For evaluating the likelihood, all score functions have to be DifferentiableStatisticalModels, but the score for class " + i + " is not.");
            }
        }
    }

    /**
     * Installs new score functions, one per class.
     * Clones are made for every additional row of {@code score}. The parameter offsets in
     * {@code shortcut} and the gradient buffers are then recomputed.
     *
     * @param funs the new score functions, one per class
     * @throws Exception propagated from cloning, from {@link #check()} or from the prior
     */
    @Override
    public void reset(DifferentiableSequenceScore[] funs) throws Exception {
        for (int i = 0; i < cl; i++) {
            score[0][i] = funs[i];
            for (int j = 1; j < score.length; j++) {
                score[j][i] = score[0][i].clone();
            }
            shortcut[i + 1] = shortcut[i] + score[0][i].getNumberOfParameters();
        }
        check();
        if (beta[PRIOR_WEIGHT] > 0.0 && prior != null) {
            prior.set(freeParams, score[0]);
        }
        int numberOfParameters = shortcut[cl];
        llGrad = new double[getNumberOfThreads()][numberOfParameters];
        cllGrad = new double[llGrad.length][numberOfParameters];
        prGrad = new double[numberOfParameters];
    }
}

