/*
 * Decompiled with CFR 0.152.
 */
package de.jstacs.classifier.scoringFunctionBased.cll;

import de.jstacs.WrongAlphabetException;
import de.jstacs.algorithms.optimization.DimensionException;
import de.jstacs.algorithms.optimization.EvaluationException;
import de.jstacs.classifier.scoringFunctionBased.OptimizableFunction;
import de.jstacs.classifier.scoringFunctionBased.logPrior.DoesNothingLogPrior;
import de.jstacs.classifier.scoringFunctionBased.logPrior.LogPrior;
import de.jstacs.data.Sample;
import de.jstacs.data.Sequence;
import de.jstacs.scoringFunctions.ScoringFunction;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.IntList;
import de.jstacs.utils.Normalisation;
import java.util.Arrays;

/**
 * Weighted conditional log-likelihood of the class labels given the sequences, plus a
 * {@link LogPrior}, optionally divided by the total data weight.
 *
 * <p>For class parameters {@code b_k} and scoring functions {@code s_k}, each sequence {@code x} of
 * class {@code c} contributes {@code w * (b_c + log s_c(x) - log sum_k exp(b_k + log s_k(x)))}.
 *
 * <p>Parameter layout: the first {@code shortcut[0]} entries are class parameters. There are
 * {@code cl - 1} of them when {@code freeParams} is set (the last class parameter is then fixed to
 * 0), otherwise {@code cl}. They are followed by the parameters of {@code score[0]},
 * {@code score[1]}, ... starting at offsets {@code shortcut[1..cl-1]}. {@code shortcut[cl]} is the
 * total dimension. The offsets {@code shortcut[1..cl]} are only computed by {@link #reset}, so
 * {@code reset} has to be called before the function can be evaluated.
 *
 * <p>The instance reuses internal scratch arrays, so a single instance should not be evaluated
 * concurrently.
 */
public class NormConditionalLogLikelihood
extends OptimizableFunction {
    /** One scoring function per class; entries are replaced by {@link #reset}. */
    private final ScoringFunction[] score;
    /** Parameter offsets, see class documentation; length {@code cl + 1}. */
    private final int[] shortcut;
    /** Training data, one sample per class. */
    private final Sample[] data;
    /** {@code weights[c][n]} is the weight of sequence {@code n} of {@code data[c]}. */
    private final double[][] weights;
    /** Scratch buffer holding per-class log joints or posteriors for the current sequence. */
    private final double[] helpArray;
    /** Current class parameters (0 for the fixed last class when {@code freeParams} is set). */
    private final double[] logClazz;
    /** {@code sum[i]} is the total weight of class {@code i}; {@code sum[cl]} the overall total. */
    private final double[] sum;
    /** Scratch lists for partial derivatives (values) per class. */
    private final DoubleList[] dList;
    /** Scratch lists for partial derivatives (parameter indices) per class. */
    private final IntList[] iList;
    /** Number of classes. */
    private final int cl;
    private final LogPrior prior;
    /** Whether function value and gradient are divided by the total weight {@code sum[cl]}. */
    private final boolean norm;
    /** Whether the last class parameter is fixed to 0 and therefore excluded from the parameters. */
    private final boolean freeParams;

    /**
     * Creates the function without a prior (a {@code DoesNothingLogPrior} is used).
     *
     * @param score one scoring function per class (at least two)
     * @param data one sample per class
     * @param weights one weight array per class, matching the number of sequences in the sample
     * @param norm whether to divide value and gradient by the total weight
     * @param freeParams whether to fix the last class parameter to 0
     * @throws IllegalArgumentException if the numbers of classes, samples or weights are inconsistent
     * @throws WrongAlphabetException declared for compatibility
     */
    public NormConditionalLogLikelihood(ScoringFunction[] score, Sample[] data, double[][] weights, boolean norm, boolean freeParams) throws IllegalArgumentException, WrongAlphabetException {
        this(score, data, weights, null, norm, freeParams);
    }

    /**
     * Creates the function.
     *
     * @param score one scoring function per class (at least two)
     * @param data one sample per class
     * @param weights one weight array per class, matching the number of sequences in the sample
     * @param prior the log prior; {@code null} selects {@code DoesNothingLogPrior.defaultInstance}
     * @param norm whether to divide value and gradient by the total weight
     * @param freeParams whether to fix the last class parameter to 0
     * @throws IllegalArgumentException if the numbers of classes, samples or weights are inconsistent
     * @throws WrongAlphabetException declared for compatibility
     */
    public NormConditionalLogLikelihood(ScoringFunction[] score, Sample[] data, double[][] weights, LogPrior prior, boolean norm, boolean freeParams) throws IllegalArgumentException, WrongAlphabetException {
        this.prior = prior == null ? DoesNothingLogPrior.defaultInstance : prior;
        this.norm = norm;
        this.freeParams = freeParams;
        this.cl = score.length;
        if (this.cl < 2 || this.cl != data.length) {
            throw new IllegalArgumentException("The number of classes is not correct. Check the the length of the constraint array as well as the length of the array f.");
        }
        if (weights == null || weights.length != this.cl) {
            throw new IllegalArgumentException("The number of weight arrays must equal the number of classes (" + this.cl + ").");
        }
        for (int i = 0; i < this.cl; ++i) {
            // a mismatch would either crash inside evaluate* or silently distort sum[]
            if (weights[i] == null || weights[i].length != data[i].getNumberOfElements()) {
                throw new IllegalArgumentException("The weights of class " + i + " do not match the number of sequences in its sample.");
            }
        }
        this.shortcut = new int[this.cl + 1];
        this.shortcut[0] = freeParams ? this.cl - 1 : this.cl;
        this.score = score;
        this.data = data;
        this.weights = weights;
        this.helpArray = new double[this.cl];
        this.logClazz = new double[this.cl];
        this.dList = new DoubleList[this.cl];
        this.iList = new IntList[this.cl];
        this.sum = new double[this.cl + 1];
        for (int i = 0; i < this.cl; ++i) {
            this.dList[i] = new DoubleList();
            this.iList[i] = new IntList();
            for (double w : weights[i]) {
                this.sum[i] += w;
            }
            this.sum[this.cl] += this.sum[i];
        }
    }

    /**
     * Computes the gradient of {@link #evaluateFunction} at {@code x}.
     *
     * @param x the parameter vector, length {@link #getDimensionOfScope()}
     * @return a new gradient array of the same length
     * @throws DimensionException if {@code x} has the wrong length
     * @throws EvaluationException if the prior gradient cannot be evaluated
     */
    public double[] evaluateGradientOfFunction(double[] x) throws DimensionException, EvaluationException {
        this.setParams(x);
        double[] grad = new double[this.shortcut[this.cl]];
        for (int c = 0; c < this.cl; ++c) {
            for (int n = 0; n < this.data[c].getNumberOfElements(); ++n) {
                Sequence s = this.data[c].getElementAt(n);
                double weight = this.weights[c][n];
                // log joint b_k + log s_k(x) together with the partial derivatives of log s_k(x)
                for (int k = 0; k < this.cl; ++k) {
                    this.iList[k].clear();
                    this.dList[k].clear();
                    this.helpArray[k] = this.logClazz[k] + this.score[k].getLogScoreAndPartialDerivation(s, 0, this.iList[k], this.dList[k]);
                }
                // overwrites helpArray with the normalised class posteriors P(k | x)
                Normalisation.logSumNormalisation(this.helpArray, 0, this.helpArray.length, this.helpArray, 0);
                // class parameters: d log P(c|x) / d b_k = [k == c] - P(k|x)
                for (int k = 0; k < this.shortcut[0]; ++k) {
                    if (k == c) {
                        grad[k] += weight * (1.0 - this.helpArray[k]);
                    } else {
                        grad[k] -= weight * this.helpArray[k];
                    }
                }
                // scoring-function parameters: ([k == c] - P(k|x)) * d log s_k(x) / d theta
                for (int k = 0; k < this.cl; ++k) {
                    int offset = this.shortcut[k];
                    for (int j = 0; j < this.iList[k].length(); ++j) {
                        int idx = offset + this.iList[k].get(j);
                        if (k == c) {
                            grad[idx] += weight * this.dList[k].get(j) * (1.0 - this.helpArray[k]);
                        } else {
                            grad[idx] -= weight * this.dList[k].get(j) * this.helpArray[k];
                        }
                    }
                }
            }
        }
        this.prior.addGradientFor(x, grad);
        if (this.norm) {
            for (int i = 0; i < grad.length; ++i) {
                grad[i] /= this.sum[this.cl];
            }
        }
        return grad;
    }

    /**
     * Evaluates the weighted conditional log-likelihood plus the log prior at {@code x}.
     *
     * @param x the parameter vector, length {@link #getDimensionOfScope()}
     * @return the value, divided by the total weight if {@code norm} is set
     * @throws DimensionException if {@code x} has the wrong length
     * @throws EvaluationException if the result is NaN; the message contains both summands and
     *     the parameter vector
     */
    public double evaluateFunction(double[] x) throws DimensionException, EvaluationException {
        this.setParams(x);
        double cll = 0.0;
        for (int c = 0; c < this.cl; ++c) {
            for (int n = 0; n < this.data[c].getNumberOfElements(); ++n) {
                Sequence s = this.data[c].getElementAt(n);
                for (int k = 0; k < this.cl; ++k) {
                    this.helpArray[k] = this.logClazz[k] + this.score[k].getLogScore(s, 0);
                }
                cll += this.weights[c][n] * (this.helpArray[c] - Normalisation.getLogSum(this.helpArray));
            }
        }
        double pr = this.prior.evaluateFunction(x);
        double value = cll + pr;
        if (Double.isNaN(value)) {
            // report diagnostics through the exception instead of printing to stdout
            throw new EvaluationException("Evaluating the function gives: " + cll + " + " + pr + " for params " + Arrays.toString(x));
        }
        return this.norm ? value / this.sum[this.cl] : value;
    }

    /**
     * Returns the total number of parameters, {@code shortcut[cl]}; only meaningful after {@link #reset}.
     */
    public int getDimensionOfScope() {
        return this.shortcut[this.cl];
    }

    /**
     * Fills {@code erg} with start parameters.
     *
     * <p>If {@code plugIn} is set, the class parameters are the log class weights relative to the last
     * class (when {@code freeParams} is set) or to the total weight. The scoring-function parameters
     * are copied from their current values. Otherwise {@code erg} is left unchanged.
     *
     * @param plugIn whether to write plug-in start values
     * @param erg the output array of length {@link #getDimensionOfScope()}
     * @throws Exception if {@code erg} is {@code null} or has the wrong length
     */
    public void getStartParams(boolean plugIn, double[] erg) throws Exception {
        if (erg == null || erg.length != this.getDimensionOfScope()) {
            throw new Exception("Null argument or length do not match.");
        }
        if (!plugIn) {
            return;
        }
        double logReference = Math.log(this.freeParams ? this.sum[this.cl - 1] : this.sum[this.cl]);
        for (int i = 0; i < this.cl; ++i) {
            if (i < this.shortcut[0]) {
                erg[i] = Math.log(this.sum[i]) - logReference;
            }
            System.arraycopy(this.score[i].getCurrentParameterValues(), 0, erg, this.shortcut[i], this.score[i].getNumberOfParameters());
        }
    }

    /**
     * Allocates a new array and fills it via {@link #getStartParams(boolean, double[])}.
     *
     * @param plugIn whether to write plug-in start values (otherwise the array stays all zero)
     * @return the start parameters
     * @throws Exception propagated from {@link #getStartParams(boolean, double[])}
     */
    public double[] getStartParams(boolean plugIn) throws Exception {
        double[] start = new double[this.getDimensionOfScope()];
        this.getStartParams(plugIn, start);
        return start;
    }

    /**
     * Distributes {@code params} to the class parameters and the scoring functions.
     *
     * @param params the parameter vector, length {@link #getDimensionOfScope()}
     * @throws DimensionException if {@code params} is {@code null} (reported as length 0) or has the
     *     wrong length
     */
    public void setParams(double[] params) throws DimensionException {
        int dimension = this.shortcut[this.cl];
        if (params == null) {
            throw new DimensionException(0, dimension);
        }
        if (params.length != dimension) {
            throw new DimensionException(params.length, dimension);
        }
        for (int i = 0; i < this.cl; ++i) {
            // shortcut[0] == cl unless freeParams is set, in which case the last class is fixed to 0
            this.logClazz[i] = i < this.shortcut[0] ? params[i] : 0.0;
            this.score[i].setParameters(params, this.shortcut[i]);
        }
    }

    /**
     * Extracts the {@code cl} class parameters from a full parameter vector.
     *
     * @param params a full parameter vector
     * @return the class parameters; the last entry is 0 when {@code freeParams} is set
     */
    public double[] getClassParams(double[] params) {
        double[] res = new double[this.cl];
        System.arraycopy(params, 0, res, 0, this.shortcut[0]);
        if (this.freeParams) {
            res[this.shortcut[0]] = 0.0;
        }
        return res;
    }

    /**
     * Returns the largest number of recommended starts over all scoring functions.
     */
    public int getNumberOfStarts() {
        int starts = this.score[0].getNumberOfRecommendedStarts();
        for (int i = 1; i < this.score.length; ++i) {
            starts = Math.max(starts, this.score[i].getNumberOfRecommendedStarts());
        }
        return starts;
    }

    /**
     * Replaces the scoring functions, recomputes the parameter offsets and re-initialises the prior.
     *
     * @param funs one scoring function per class
     * @throws IllegalArgumentException if {@code funs} does not have one entry per class
     * @throws Exception propagated from the prior
     */
    public void reset(ScoringFunction[] funs) throws Exception {
        if (funs.length != this.cl) {
            throw new IllegalArgumentException("Could not reset.");
        }
        for (int i = 0; i < this.cl; ++i) {
            this.score[i] = funs[i];
            this.shortcut[i + 1] = this.shortcut[i] + this.score[i].getNumberOfParameters();
        }
        // prior is never null: the constructor substitutes DoesNothingLogPrior.defaultInstance
        this.prior.set(this.freeParams, this.score);
    }
}

