/*
 * Decompiled with CFR 0.152.
 */
package de.jstacs.classifier.scoringFunctionBased.cll;

import de.jstacs.WrongAlphabetException;
import de.jstacs.algorithms.optimization.DimensionException;
import de.jstacs.algorithms.optimization.EvaluationException;
import de.jstacs.classifier.scoringFunctionBased.AbstractOptimizableFunction;
import de.jstacs.classifier.scoringFunctionBased.OptimizableFunction;
import de.jstacs.classifier.scoringFunctionBased.logPrior.DoesNothingLogPrior;
import de.jstacs.classifier.scoringFunctionBased.logPrior.LogPrior;
import de.jstacs.data.Sample;
import de.jstacs.data.Sequence;
import de.jstacs.scoringFunctions.ScoringFunction;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.IntList;
import de.jstacs.utils.Normalisation;
import java.util.Arrays;

/**
 * Optimizable function that evaluates the (optionally normalized) conditional
 * log-likelihood of class-labelled data under one {@link ScoringFunction} per
 * class, plus the contribution of a {@link LogPrior}.
 *
 * <p>For a sequence {@code s} the per-class value is
 * {@code logClazz[k] + score[k].getLogScore(s, 0)}; the contribution of a
 * sequence from class {@code c} is its weight times the value for {@code c}
 * minus the log-sum over all classes.
 *
 * <p>Parameter vector layout as used by this class: indices
 * {@code [0, shortcut[0])} hold the class parameters and the parameters of
 * scoring function {@code i} start at index {@code shortcut[i]}; the total
 * length is {@code shortcut[cl]}. The fields {@code cl}, {@code data},
 * {@code weights}, {@code norm}, {@code freeParams}, {@code shortcut},
 * {@code sum} and {@code logClazz} are inherited from
 * {@link AbstractOptimizableFunction}.
 *
 * <p>Instances reuse internal scratch buffers across calls, so concurrent
 * evaluation on the same instance is not supported.
 */
public class NormConditionalLogLikelihood
extends AbstractOptimizableFunction {
    /** One scoring function per class; the array is the one handed in by the caller. */
    private ScoringFunction[] score;
    /** Scratch buffer holding the per-class log values / normalized values of the current sequence. */
    private double[] helpArray;
    /** Scratch buffers receiving the partial derivatives reported by each scoring function. */
    private DoubleList[] dList;
    /** Scratch buffers receiving the parameter indices that belong to the entries of {@link #dList}. */
    private IntList[] iList;
    /** The prior on the parameters; never {@code null} after construction. */
    private LogPrior prior;
    /** Pseudo count added to the class weight sums when computing plug-in class parameters. */
    private static final double EPS = 1.0E-6;

    /**
     * Creates an instance without an explicit prior.
     *
     * @param score one scoring function per class
     * @param data one sample per class
     * @param weights per-sequence weights, indexed by class and then by sequence
     * @param norm whether function value and gradient are divided by {@code sum[cl]}
     * @param freeParams passed through to the superclass and to the prior
     * @throws IllegalArgumentException if there are fewer than two classes or the
     *         number of scoring functions differs from the number of classes
     * @throws WrongAlphabetException declared for the superclass constructor
     */
    public NormConditionalLogLikelihood(ScoringFunction[] score, Sample[] data, double[][] weights, boolean norm, boolean freeParams) throws IllegalArgumentException, WrongAlphabetException {
        this(score, data, weights, null, norm, freeParams);
    }

    /**
     * Creates an instance with the given prior.
     *
     * @param score one scoring function per class
     * @param data one sample per class
     * @param weights per-sequence weights, indexed by class and then by sequence
     * @param prior the prior on the parameters; {@code null} selects
     *        {@link DoesNothingLogPrior#defaultInstance}
     * @param norm whether function value and gradient are divided by {@code sum[cl]}
     * @param freeParams passed through to the superclass and to the prior
     * @throws IllegalArgumentException if there are fewer than two classes or the
     *         number of scoring functions differs from the number of classes
     * @throws WrongAlphabetException declared for the superclass constructor
     */
    public NormConditionalLogLikelihood(ScoringFunction[] score, Sample[] data, double[][] weights, LogPrior prior, boolean norm, boolean freeParams) throws IllegalArgumentException, WrongAlphabetException {
        super(data, weights, norm, freeParams);
        if (this.cl < 2) {
            throw new IllegalArgumentException("The number of classes is not correct. Has to be at least 2.");
        }
        // Every method indexes score[0..cl), and reset() demands exactly cl functions,
        // so reject a too-short array here instead of failing later with an
        // ArrayIndexOutOfBoundsException.
        if (this.cl != score.length) {
            throw new IllegalArgumentException("The number of classes is not correct. Check the length of the score array.");
        }
        this.prior = prior == null ? DoesNothingLogPrior.defaultInstance : prior;
        this.helpArray = new double[this.cl];
        this.dList = new DoubleList[this.cl];
        this.iList = new IntList[this.cl];
        this.score = score;
        for (int i = 0; i < this.cl; ++i) {
            this.dList[i] = new DoubleList();
            this.iList[i] = new IntList();
        }
    }

    /**
     * Evaluates the gradient of the function at {@code x}.
     *
     * <p>The parameters are first pushed into the scoring functions via
     * {@link #setParams(double[])}. The prior adds its own gradient through
     * {@link LogPrior#addGradientFor}; if {@code norm} is set, every component
     * is finally divided by {@code sum[cl]}.
     *
     * @param x the parameter vector
     * @return a new array of length {@code shortcut[cl]} holding the gradient
     * @throws DimensionException declared for {@link #setParams(double[])}
     * @throws EvaluationException declared by the overridden contract
     */
    public double[] evaluateGradientOfFunction(double[] x) throws DimensionException, EvaluationException {
        this.setParams(x);
        double[] grad = new double[this.shortcut[this.cl]];
        for (int clazz = 0; clazz < this.cl; ++clazz) {
            for (int seqIdx = 0; seqIdx < this.data[clazz].getNumberOfElements(); ++seqIdx) {
                Sequence s = this.data[clazz].getElementAt(seqIdx);
                double weight = this.weights[clazz][seqIdx];

                // Collect the log value and the partial derivatives of every class for this sequence.
                for (int k = 0; k < this.cl; ++k) {
                    this.iList[k].clear();
                    this.dList[k].clear();
                    this.helpArray[k] = this.logClazz[k] + this.score[k].getLogScoreAndPartialDerivation(s, 0, this.iList[k], this.dList[k]);
                }
                // In-place normalization; the values are used below as the per-class
                // probabilities of the current sequence.
                Normalisation.logSumNormalisation(this.helpArray, 0, this.helpArray.length, this.helpArray, 0);

                // Class parameters: (indicator of the true class) - probability.
                for (int k = 0; k < this.shortcut[0]; ++k) {
                    if (k == clazz) {
                        grad[k] += weight * (1.0 - this.helpArray[k]);
                    } else {
                        grad[k] -= weight * this.helpArray[k];
                    }
                }

                // Scoring-function parameters: the same factor, scaled by each partial derivative.
                for (int k = 0; k < this.cl; ++k) {
                    int offset = this.shortcut[k];
                    int entries = this.iList[k].length();
                    if (k == clazz) {
                        for (int j = 0; j < entries; ++j) {
                            grad[offset + this.iList[k].get(j)] += weight * this.dList[k].get(j) * (1.0 - this.helpArray[k]);
                        }
                    } else {
                        for (int j = 0; j < entries; ++j) {
                            grad[offset + this.iList[k].get(j)] -= weight * this.dList[k].get(j) * this.helpArray[k];
                        }
                    }
                }
            }
        }
        this.prior.addGradientFor(x, grad);
        if (this.norm) {
            for (int j = 0; j < grad.length; ++j) {
                grad[j] /= this.sum[this.cl];
            }
        }
        return grad;
    }

    /**
     * Evaluates the function at {@code x}: the weighted conditional
     * log-likelihood of the data plus the value of the prior, divided by
     * {@code sum[cl]} if {@code norm} is set.
     *
     * @param x the parameter vector
     * @return the function value
     * @throws DimensionException declared for {@link #setParams(double[])}
     * @throws EvaluationException if the sum of log-likelihood and prior is NaN
     */
    public double evaluateFunction(double[] x) throws DimensionException, EvaluationException {
        this.setParams(x);
        double cll = 0.0;
        for (int clazz = 0; clazz < this.cl; ++clazz) {
            for (int seqIdx = 0; seqIdx < this.data[clazz].getNumberOfElements(); ++seqIdx) {
                Sequence s = this.data[clazz].getElementAt(seqIdx);
                for (int k = 0; k < this.cl; ++k) {
                    this.helpArray[k] = this.logClazz[k] + this.score[k].getLogScore(s, 0);
                }
                cll += this.weights[clazz][seqIdx] * (this.helpArray[clazz] - Normalisation.getLogSum(this.helpArray));
            }
        }
        double pr = this.prior.evaluateFunction(x);
        if (Double.isNaN(cll + pr)) {
            // Dump the offending parameters before failing; kept on stdout as existing diagnostic output.
            System.out.println("params " + Arrays.toString(x));
            System.out.flush();
            throw new EvaluationException("Evaluating the function gives: " + cll + " + " + pr);
        }
        if (this.norm) {
            return (cll + pr) / this.sum[this.cl];
        }
        return cll + pr;
    }

    /**
     * Writes start parameters into {@code erg}.
     *
     * <p>The class parameters in {@code erg[0..shortcut[0])} depend on
     * {@code kind}:
     * <ul>
     * <li>{@code PLUGIN}: {@code log(sum[i] + EPS)} minus the log of a
     * reference value, which is {@code sum[cl - 1] + EPS} if
     * {@code freeParams} is set and {@code sum[cl] + cl * EPS} otherwise;</li>
     * <li>{@code LAST}: the current values of {@code logClazz};</li>
     * <li>{@code ZEROS}: the slots are left as supplied by the caller.</li>
     * </ul>
     * In all cases the current parameter values of every scoring function are
     * copied to {@code erg} starting at {@code shortcut[i]}.
     *
     * @param kind which class parameters to produce
     * @param erg the array receiving the parameters; must have at least
     *        {@code shortcut[cl]} entries
     * @throws IllegalArgumentException if {@code kind} is not one of the handled kinds
     * @throws Exception declared by the overridden contract
     */
    public void getParameters(OptimizableFunction.KindOfParameter kind, double[] erg) throws Exception {
        int numClassParams = this.shortcut[0];
        switch (kind) {
            case PLUGIN: {
                double discount = Math.log(this.freeParams ? this.sum[this.cl - 1] + EPS : this.sum[this.cl] + this.cl * EPS);
                for (int i = 0; i < numClassParams; ++i) {
                    erg[i] = Math.log(this.sum[i] + EPS) - discount;
                }
                // Without this break the LAST case would overwrite the plug-in values.
                break;
            }
            case LAST: {
                System.arraycopy(this.logClazz, 0, erg, 0, numClassParams);
                break;
            }
            case ZEROS: {
                break;
            }
            default: {
                throw new IllegalArgumentException("Unknown kind of parameter");
            }
        }
        for (int i = 0; i < this.cl; ++i) {
            System.arraycopy(this.score[i].getCurrentParameterValues(), 0, erg, this.shortcut[i], this.score[i].getNumberOfParameters());
        }
    }

    /**
     * Sets the parameters: delegates to the superclass and then hands every
     * scoring function its slice of {@code params}, starting at
     * {@code shortcut[i]}.
     *
     * @param params the complete parameter vector
     * @throws DimensionException declared by the superclass method
     */
    public void setParams(double[] params) throws DimensionException {
        super.setParams(params);
        for (int i = 0; i < this.cl; ++i) {
            this.score[i].setParameters(params, this.shortcut[i]);
        }
    }

    /**
     * Returns the number of recommended optimization starts by delegating to
     * the {@code getNumberOfStarts(ScoringFunction[])} overload with the
     * current scoring functions.
     *
     * @return the number of starts
     */
    public int getNumberOfStarts() {
        return this.getNumberOfStarts(this.score);
    }

    /**
     * Replaces the scoring functions, recomputes the parameter offsets
     * {@code shortcut[1..cl]} from their parameter counts and passes the new
     * functions to the prior.
     *
     * @param funs the new scoring functions; exactly one per class
     * @throws IllegalArgumentException if {@code funs.length} differs from the number of classes
     * @throws Exception declared for the calls into the scoring functions and the prior
     */
    public void reset(ScoringFunction[] funs) throws Exception {
        if (funs.length != this.cl) {
            throw new IllegalArgumentException("Could not reset.");
        }
        for (int i = 0; i < this.cl; ++i) {
            this.score[i] = funs[i];
            this.shortcut[i + 1] = this.shortcut[i] + this.score[i].getNumberOfParameters();
        }
        if (this.prior != null) {
            this.prior.set(this.freeParams, this.score);
        }
    }
}

