/*
 * Decompiled with CFR 0.152.
 */
package de.jstacs.classifier.scoringFunctionBased.gendismix;

import de.jstacs.algorithms.optimization.DimensionException;
import de.jstacs.algorithms.optimization.EvaluationException;
import de.jstacs.classifier.scoringFunctionBased.SFBasedOptimizableFunction;
import de.jstacs.classifier.scoringFunctionBased.gendismix.LearningPrinciple;
import de.jstacs.classifier.scoringFunctionBased.logPrior.LogPrior;
import de.jstacs.data.Sample;
import de.jstacs.data.Sequence;
import de.jstacs.scoringFunctions.NormalizableScoringFunction;
import de.jstacs.scoringFunctions.ScoringFunction;
import de.jstacs.utils.Normalisation;
import java.util.Arrays;

/**
 * An optimizable function that combines, weighted by {@code beta}, the conditional
 * log-likelihood (CLL), the log-likelihood (LL) and the log-prior of a set of
 * class-specific {@link ScoringFunction}s:
 *
 * <pre>
 *   f = beta[0] * CLL + beta[1] * LL + beta[2] * logPrior
 * </pre>
 *
 * If {@code norm} is set, value and gradient are divided by {@code sum[cl]}
 * (inherited from the superclass; presumably the total weight of all sequences).
 *
 * <p>Partial values are computed per thread in {@code evaluateFunction} and
 * {@code evaluateGradientOfFunction} and combined in {@link #joinFunction()} and
 * {@link #joinGradients()}.
 */
public class LogGenDisMixFunction extends SFBasedOptimizableFunction {
    /**
     * Per-thread scratch space with {@code max(2, cl)} entries. The evaluation methods
     * store the log class scores ({@code logClazz[c] + log score}) here;
     * {@code evaluateFunction} finally overwrites entries 0 and 1 with the thread's
     * partial LL and CLL, which {@link #joinFunction()} sums up.
     */
    private final double[][] helpArray;
    /** Per-thread partial LL gradients; reduced into row 0 by {@link #joinGradients()}. */
    private double[][] llGrad;
    /** Per-thread partial CLL gradients; reduced into row 0 by {@link #joinGradients()}. */
    private double[][] cllGrad;
    /** Weights of the CLL ({@code [0]}), the LL ({@code [1]}) and the log-prior ({@code [2]}). */
    private final double[] beta;
    /** Buffer receiving the gradient of the log-prior. */
    private double[] prGrad;

    /**
     * Creates a new function.
     *
     * @param threads the number of threads used for evaluation
     * @param score the scoring functions, one per class
     * @param data the data, indexed by class
     * @param weights the sequence weights, {@code weights[c][i]} for sequence {@code i} of class {@code c}
     * @param prior the log-prior
     * @param beta the weights of CLL, LL and log-prior; validated by {@code LearningPrinciple.checkWeights}
     * @param norm whether value and gradient are divided by {@code sum[cl]}
     * @param freeParams passed to the superclass and, in {@link #reset(ScoringFunction[])}, to the prior
     * @throws IllegalArgumentException if there is no class, if the CLL weight is non-zero for fewer
     *         than two classes, or if the LL weight is non-zero but some scoring function is not a
     *         {@link NormalizableScoringFunction}
     */
    public LogGenDisMixFunction(int threads, ScoringFunction[] score, Sample[] data, double[][] weights, LogPrior prior, double[] beta, boolean norm, boolean freeParams) throws IllegalArgumentException {
        super(threads, score, data, weights, prior, norm, freeParams);
        // The CLL is only defined when there are at least two classes to discriminate.
        if (this.cl < 1 || (beta[0] != 0.0 && this.cl < 2)) {
            throw new IllegalArgumentException("The number of classes is not correct. You can use this class for the generative training of one class or the (in some kind) discriminative training of more than one class.");
        }
        this.beta = LearningPrinciple.checkWeights(beta);
        this.check();
        // At least two entries so that evaluateFunction can always store (LL, CLL).
        this.helpArray = new double[threads][Math.max(2, this.cl)];
    }

    /**
     * Reduces the per-thread partial gradients and returns the gradient of the combined function.
     * <p>
     * If the LL weight is non-zero, the gradient of the LL normalization term
     * {@code -sum[cl] * log(sum_c exp(logClazz[c] + logNormalizationConstant_c))} is added here,
     * because it does not depend on individual sequences.
     *
     * @return the gradient, of length {@code shortcut[cl]}
     * @throws EvaluationException if a partial normalization constant cannot be computed
     */
    @Override
    protected double[] joinGradients() throws EvaluationException {
        // Sum the partial gradients of all threads into row 0 (per element in thread order).
        for (int t = 1; t < this.llGrad.length; ++t) {
            for (int i = 0; i < this.llGrad[0].length; ++i) {
                this.llGrad[0][i] += this.llGrad[t][i];
                this.cllGrad[0][i] += this.cllGrad[t][i];
            }
        }
        if (this.beta[1] != 0.0) {
            try {
                double[] logWeightedNorm = new double[this.cl];
                double logZ = Double.NEGATIVE_INFINITY;
                for (int c = 0; c < this.cl; ++c) {
                    logWeightedNorm[c] = this.logClazz[c] + ((NormalizableScoringFunction) this.score[0][c]).getLogNormalizationConstant();
                    logZ = Normalisation.getLogSum(logZ, logWeightedNorm[c]);
                }
                for (int c = 0; c < this.cl; ++c) {
                    NormalizableScoringFunction fun = (NormalizableScoringFunction) this.score[0][c];
                    // Indices below shortcut[0] hold the class parameters.
                    if (c < this.shortcut[0]) {
                        this.llGrad[0][c] -= this.sum[this.cl] * Math.exp(logWeightedNorm[c] - logZ);
                    }
                    for (int p = this.shortcut[c]; p < this.shortcut[c + 1]; ++p) {
                        this.llGrad[0][p] -= this.sum[this.cl] * Math.exp(this.logClazz[c] + fun.getLogPartialNormalizationConstant(p - this.shortcut[c]) - logZ);
                    }
                }
            } catch (Exception e) {
                // Keep the original failure location visible in the rethrown exception.
                EvaluationException eva = new EvaluationException(e.getMessage());
                eva.setStackTrace(e.getStackTrace());
                throw eva;
            }
        }
        Arrays.fill(this.prGrad, 0.0);
        if (this.beta[2] != 0.0) {
            this.prior.addGradientFor(this.params, this.prGrad);
        }
        // Zero unused terms so that 0 * (NaN or infinite partials) cannot poison the result.
        if (this.beta[1] == 0.0) {
            Arrays.fill(this.llGrad[0], 0.0);
        }
        if (this.beta[0] == 0.0) {
            Arrays.fill(this.cllGrad[0], 0.0);
        }
        double[] grad = new double[this.shortcut[this.cl]];
        double weight = this.norm ? this.sum[this.cl] : 1.0;
        for (int i = 0; i < grad.length; ++i) {
            grad[i] = (this.beta[1] * this.llGrad[0][i] + this.beta[0] * this.cllGrad[0][i] + this.beta[2] * this.prGrad[i]) / weight;
        }
        return grad;
    }

    /**
     * Accumulates this thread's partial LL and CLL gradients into {@code llGrad[index]} and
     * {@code cllGrad[index]} for the sequences from ({@code startClass}, {@code startSeq})
     * inclusive to ({@code endClass}, {@code endSeq}) exclusive.
     */
    @Override
    protected void evaluateGradientOfFunction(int index, int startClass, int startSeq, int endClass, int endSeq) {
        double[] ll = this.llGrad[index];
        double[] cll = this.cllGrad[index];
        double[] classScores = this.helpArray[index];
        Arrays.fill(ll, 0.0);
        Arrays.fill(cll, 0.0);
        for (int clazz = startClass; clazz <= endClass; ++clazz) {
            int start = clazz == startClass ? startSeq : 0;
            int end = clazz == endClass ? endSeq : this.data[clazz].getNumberOfElements();
            for (int seq = start; seq < end; ++seq) {
                Sequence s = this.data[clazz].getElementAt(seq);
                double weight = this.weights[clazz][seq];
                if (this.beta[0] != 0.0) {
                    // The CLL needs scores and derivatives of every class, not only the true one.
                    for (int c = 0; c < this.cl; ++c) {
                        this.computeScoreAndDerivatives(index, c, s);
                    }
                } else {
                    this.computeScoreAndDerivatives(index, clazz, s);
                }
                if (this.beta[1] != 0.0) {
                    if (clazz < this.shortcut[0]) {
                        ll[clazz] += weight;
                    }
                    for (int k = 0; k < this.iList[index][clazz].length(); ++k) {
                        ll[this.shortcut[clazz] + this.iList[index][clazz].get(k)] += weight * this.dList[index][clazz].get(k);
                    }
                }
                if (this.beta[0] == 0.0) {
                    continue;
                }
                // Turn the log class scores into normalized values used below as class posteriors.
                Normalisation.logSumNormalisation(classScores, 0, classScores.length, classScores, 0);
                for (int c = 0; c < this.shortcut[0]; ++c) {
                    if (c == clazz) {
                        cll[c] += weight * (1.0 - classScores[c]);
                    } else {
                        cll[c] -= weight * classScores[c];
                    }
                }
                for (int c = 0; c < this.cl; ++c) {
                    int offset = this.shortcut[c];
                    for (int k = 0; k < this.iList[index][c].length(); ++k) {
                        int p = offset + this.iList[index][c].get(k);
                        if (c == clazz) {
                            cll[p] += weight * this.dList[index][c].get(k) * (1.0 - classScores[c]);
                        } else {
                            cll[p] -= weight * this.dList[index][c].get(k) * classScores[c];
                        }
                    }
                }
            }
        }
    }

    /**
     * Stores {@code logClazz[c] + log score of s} in {@code helpArray[index][c]} and the
     * partial derivatives of the score in {@code iList[index][c]} / {@code dList[index][c]}.
     */
    private void computeScoreAndDerivatives(int index, int c, Sequence s) {
        this.iList[index][c].clear();
        this.dList[index][c].clear();
        this.helpArray[index][c] = this.logClazz[c] + this.score[index][c].getLogScoreAndPartialDerivation(s, 0, this.iList[index][c], this.dList[index][c]);
    }

    /**
     * Sums the per-thread partial LL and CLL, adds the LL normalization term and the log-prior,
     * and returns their weighted combination (divided by {@code sum[cl]} if {@code norm} is set).
     *
     * @throws EvaluationException if the combined value is NaN or infinite
     */
    @Override
    protected double joinFunction() throws DimensionException, EvaluationException {
        double ll = 0.0;
        double cll = 0.0;
        double lpr = 0.0;
        for (double[] partial : this.helpArray) {
            ll += partial[0];
            cll += partial[1];
        }
        if (this.beta[1] != 0.0) {
            double logZ = Double.NEGATIVE_INFINITY;
            for (int c = 0; c < this.cl; ++c) {
                logZ = Normalisation.getLogSum(logZ, this.logClazz[c] + ((NormalizableScoringFunction) this.score[0][c]).getLogNormalizationConstant());
            }
            ll -= this.sum[this.cl] * logZ;
        }
        if (this.beta[2] != 0.0) {
            lpr = this.prior.evaluateFunction(this.params);
        }
        // Zero unused terms so that 0 * (NaN or infinite partials) cannot poison the result.
        if (this.beta[1] == 0.0) {
            ll = 0.0;
        }
        if (this.beta[0] == 0.0) {
            cll = 0.0;
        }
        double res = this.beta[1] * ll + this.beta[0] * cll + this.beta[2] * lpr;
        if (Double.isNaN(res) || Double.isInfinite(res)) {
            System.out.println(res + "\t= " + this.beta[0] + " * " + cll + " + " + this.beta[1] + " * " + ll + " + " + this.beta[2] + " * " + lpr);
            System.out.println("params " + Arrays.toString(this.params));
            System.out.flush();
            throw new EvaluationException("Evaluating the function gives: " + this.beta[0] + " * " + cll + " + " + this.beta[1] + " * " + ll + " + " + this.beta[2] + " * " + lpr);
        }
        return this.norm ? res / this.sum[this.cl] : res;
    }

    /**
     * Computes this thread's partial (unnormalized) LL and CLL for the sequences from
     * ({@code startClass}, {@code startSeq}) inclusive to ({@code endClass}, {@code endSeq})
     * exclusive, and stores them in {@code helpArray[index][0]} and {@code helpArray[index][1]}.
     */
    @Override
    protected void evaluateFunction(int index, int startClass, int startSeq, int endClass, int endSeq) throws EvaluationException {
        double[] classScores = this.helpArray[index];
        double cll = 0.0;
        double ll = 0.0;
        for (int clazz = startClass; clazz <= endClass; ++clazz) {
            int start = clazz == startClass ? startSeq : 0;
            int end = clazz == endClass ? endSeq : this.data[clazz].getNumberOfElements();
            for (int seq = start; seq < end; ++seq) {
                Sequence s = this.data[clazz].getElementAt(seq);
                double w = this.weights[clazz][seq];
                if (this.beta[0] != 0.0) {
                    for (int c = 0; c < this.cl; ++c) {
                        classScores[c] = this.logClazz[c] + this.score[index][c].getLogScore(s, 0);
                    }
                    cll += w * (classScores[clazz] - Normalisation.getLogSum(classScores));
                } else {
                    classScores[clazz] = this.logClazz[clazz] + this.score[index][clazz].getLogScore(s, 0);
                }
                ll += w * classScores[clazz];
            }
        }
        // Reuse the scratch array to hand this thread's results to joinFunction.
        classScores[0] = ll;
        classScores[1] = cll;
    }

    /**
     * Ensures that all scoring functions are normalizable whenever the LL weight is non-zero,
     * since the LL needs their normalization constants.
     *
     * @throws IllegalArgumentException if a scoring function is not a {@link NormalizableScoringFunction}
     */
    private void check() throws IllegalArgumentException {
        if (this.beta[1] == 0.0) {
            return;
        }
        for (int i = 0; i < this.score[0].length; ++i) {
            if (!(this.score[0][i] instanceof NormalizableScoringFunction)) {
                throw new IllegalArgumentException("For evaluating the likelihood, all scoring functions have to be NormalizableScoringFunctions, but the scoring function for class " + i + " is not.");
            }
        }
    }

    /**
     * Replaces the scoring functions by {@code funs}. It also:
     * <ul>
     *   <li>clones them for the remaining threads;</li>
     *   <li>recomputes the parameter offsets {@code shortcut};</li>
     *   <li>re-validates the functions;</li>
     *   <li>re-initializes the prior if it is used;</li>
     *   <li>reallocates the gradient buffers.</li>
     * </ul>
     *
     * @param funs the new scoring functions, one per class
     * @throws Exception if cloning fails or the functions are invalid
     */
    @Override
    public void reset(ScoringFunction[] funs) throws Exception {
        for (int i = 0; i < this.cl; ++i) {
            this.score[0][i] = funs[i];
            for (int j = 1; j < this.score.length; ++j) {
                this.score[j][i] = this.score[0][i].clone();
            }
            this.shortcut[i + 1] = this.shortcut[i] + this.score[0][i].getNumberOfParameters();
        }
        this.check();
        if (this.beta[2] > 0.0 && this.prior != null) {
            this.prior.set(this.freeParams, this.score[0]);
        }
        this.llGrad = new double[this.getNumberOfThreads()][this.shortcut[this.cl]];
        this.cllGrad = new double[this.llGrad.length][this.shortcut[this.cl]];
        this.prGrad = new double[this.shortcut[this.cl]];
    }
}

