/*
 * Decompiled with CFR 0.152.
 */
package de.jstacs.classifiers.differentiableSequenceScoreBased.logPrior;

import de.jstacs.algorithms.optimization.DimensionException;
import de.jstacs.algorithms.optimization.EvaluationException;
import de.jstacs.classifiers.differentiableSequenceScoreBased.logPrior.LogPrior;
import de.jstacs.sequenceScores.differentiable.DifferentiableSequenceScore;
import de.jstacs.sequenceScores.statisticalModels.differentiable.DifferentiableStatisticalModel;
import de.jstacs.utils.Normalisation;
import de.jtem.numericalMethods.calculus.specialFunctions.Gamma;

/**
 * Log prior that combines one class parameter per component with the log prior
 * terms of several {@link DifferentiableStatisticalModel}s.
 *
 * <p>The value computed by {@link #evaluateFunction(double[])} is
 * <pre>
 *   logGammaSum - fullEss * log( sum_i exp(classPar_i + logZ_i) )
 *               + sum_i ( classPar_i * ess_i + logPriorTerm_i )
 * </pre>
 * where {@code logZ_i} is the log normalization constant of component {@code i},
 * {@code ess_i} its equivalent sample size, {@code fullEss} the sum of all
 * {@code ess_i} and {@code logGammaSum = logGamma(fullEss) - sum_i logGamma(ess_i)}.
 *
 * <p>If {@code freeParameters} is {@code true}, the class parameter of the last
 * component is fixed to {@code 0} and is not part of the parameter vector.
 *
 * <p>An instance must be configured with
 * {@link #set(boolean, DifferentiableSequenceScore...)} before any of the
 * evaluation methods is used.
 */
public class CompositeLogPrior
extends LogPrior {
    /** The component models, in the order given to {@code set}. */
    private DifferentiableStatisticalModel[] function;
    /** Sum of all entries of {@link #ess}. */
    private double fullEss;
    /** {@code logGamma(fullEss) - sum_i logGamma(ess[i])}. */
    private double logGammaSum;
    /** Equivalent sample size of each component as returned by {@code getESS()}. */
    private double[] ess;
    /** Scratch buffer: class parameter of each component during gradient computation. */
    private double[] classPars;
    /** Scratch buffer: {@code classPars[i] + logZ_i}; overwritten by the normalisation call. */
    private double[] logPartNorm;
    /** Whether the class parameter of the last component is fixed to zero. */
    private boolean freeParameters;

    /**
     * Creates an unconfigured prior; call
     * {@link #set(boolean, DifferentiableSequenceScore...)} before use.
     */
    public CompositeLogPrior() {
    }

    /**
     * Creates an unconfigured prior. The XML content is not read, which matches
     * {@link #toXML()} writing an empty buffer.
     *
     * @param xml ignored
     */
    public CompositeLogPrior(StringBuffer xml) {
        this();
    }

    /**
     * Configures this prior for the given component models.
     *
     * <p>All inputs are validated before any field is modified, so a failed call
     * leaves a previously valid configuration untouched.
     *
     * @param freeParameters if {@code true}, the class parameter of the last
     *        component is fixed to zero and omitted from the parameter vector
     * @param funs the component models; each must be a
     *        {@link DifferentiableStatisticalModel} with a positive ESS
     * @throws Exception if a component is not a {@link DifferentiableStatisticalModel}
     * @throws IllegalArgumentException if the ESS of a component is not positive
     *         (zero, negative or NaN)
     */
    @Override
    public void set(boolean freeParameters, DifferentiableSequenceScore ... funs) throws Exception {
        int count = funs.length;
        DifferentiableStatisticalModel[] models = new DifferentiableStatisticalModel[count];
        double[] newEss = new double[count];
        double essSum = 0.0;
        double gammaSum = 0.0;
        for (int i = 0; i < count; ++i) {
            if (!(funs[i] instanceof DifferentiableStatisticalModel)) {
                throw new Exception("Only DifferentiableStatisticalModel allowed.");
            }
            models[i] = (DifferentiableStatisticalModel)funs[i];
            newEss[i] = models[i].getESS();
            // Negated comparison so that NaN is rejected as well as zero and negative values.
            if (!(newEss[i] > 0.0)) {
                throw new IllegalArgumentException("The ess of the function " + i + " is " + newEss[i] + ", but should be positive.");
            }
            essSum += newEss[i];
            gammaSum -= Gamma.logOfGamma(newEss[i]);
        }
        gammaSum += Gamma.logOfGamma(essSum);

        // Commit only after every component passed validation.
        this.function = models;
        this.ess = newEss;
        this.classPars = new double[count];
        this.logPartNorm = new double[count];
        this.fullEss = essSum;
        this.logGammaSum = gammaSum;
        this.freeParameters = freeParameters;
    }

    /**
     * Adds the gradient of this log prior to {@code grad}.
     *
     * <p>The first entries of {@code params} and {@code grad} belong to the class
     * parameters (one per component, minus one if {@code freeParameters} is set);
     * the entries for the parameters of the individual components follow
     * consecutively in component order.
     *
     * @param params the current parameter vector; only the class parameters are read here
     * @param grad the gradient vector that is incremented in place
     * @throws EvaluationException if any step of the computation fails; the stack
     *         trace of the original failure is transferred to the thrown exception
     */
    @Override
    public void addGradientFor(double[] params, double[] grad) throws EvaluationException {
        try {
            int numClassPars = this.function.length - (this.freeParameters ? 1 : 0);
            for (int k = 0; k < numClassPars; ++k) {
                this.classPars[k] = params[k];
                this.logPartNorm[k] = this.classPars[k] + this.function[k].getLogNormalizationConstant();
            }
            if (this.freeParameters) {
                // The last class parameter is fixed to zero.
                this.classPars[numClassPars] = 0.0;
                this.logPartNorm[numClassPars] = this.function[numClassPars].getLogNormalizationConstant();
            }
            // NOTE(review): presumably this call normalises logPartNorm in place to
            // probabilities and returns the log of the sum; the use of logPartNorm as a
            // plain factor below relies on that - confirm against Normalisation.
            double fullNorm = Normalisation.logSumNormalisation(this.logPartNorm);

            // Gradient with respect to the class parameters.
            int start;
            for (start = 0; start < numClassPars; ++start) {
                grad[start] += this.ess[start] - this.fullEss * this.logPartNorm[start];
            }

            // Gradient with respect to the parameters of each component;
            // 'start' keeps running over the consecutive parameter slots.
            for (int j = 0; j < this.function.length; ++j) {
                this.function[j].addGradientOfLogPriorTerm(grad, start);
                int numPars = this.function[j].getNumberOfParameters();
                for (int l = 0; l < numPars; ++l) {
                    grad[start++] -= this.fullEss * Math.exp(this.classPars[j] + this.function[j].getLogPartialNormalizationConstant(l) - fullNorm);
                }
            }
        }
        catch (Exception e) {
            throw CompositeLogPrior.toEvaluationException(e.getMessage(), e);
        }
    }

    /**
     * Evaluates the log prior for the given parameter vector.
     *
     * <p>Only the leading class parameters of {@code x} are read here; the values
     * contributed by the components are obtained from the component models
     * themselves.
     *
     * @param x the parameter vector, starting with the class parameters
     * @return the value of the log prior
     * @throws DimensionException declared by the overridden method; not thrown by this implementation
     * @throws EvaluationException if any step of the computation fails; the stack
     *         trace of the original failure is transferred to the thrown exception
     */
    @Override
    public double evaluateFunction(double[] x) throws DimensionException, EvaluationException {
        try {
            int numClassPars = this.function.length - (this.freeParameters ? 1 : 0);
            double logNorm = Double.NEGATIVE_INFINITY;
            double logProductPart = 0.0;
            for (int i = 0; i < numClassPars; ++i) {
                logNorm = Normalisation.getLogSum(logNorm, x[i] + this.function[i].getLogNormalizationConstant());
                logProductPart += x[i] * this.ess[i] + this.function[i].getLogPriorTerm();
            }
            if (this.freeParameters) {
                // Last component: class parameter fixed to zero, so it contributes no x * ess term.
                logNorm = Normalisation.getLogSum(logNorm, this.function[numClassPars].getLogNormalizationConstant());
                logProductPart += this.function[numClassPars].getLogPriorTerm();
            }
            return this.logGammaSum - this.fullEss * logNorm + logProductPart;
        }
        catch (Exception e) {
            // Prefer the message of the wrapped cause, but do not fail when there is none.
            Throwable cause = e.getCause();
            String message = cause != null ? cause.getMessage() : e.getMessage();
            throw CompositeLogPrior.toEvaluationException(message, e);
        }
    }

    /**
     * Builds the {@link EvaluationException} reported for a failed computation.
     *
     * @param message the preferred message; if {@code null}, the string form of
     *        {@code source} is used so that the failure type is not lost
     * @param source the original failure whose stack trace is transferred
     * @return the exception to throw
     */
    private static EvaluationException toEvaluationException(String message, Exception source) {
        EvaluationException eva = new EvaluationException(message != null ? message : source.toString());
        eva.setStackTrace(source.getStackTrace());
        return eva;
    }

    /**
     * Returns the number of parameters this prior is defined over: the class
     * parameters plus the parameters of all components.
     *
     * @return the total number of parameters, or {@code -1} as soon as a
     *         component reports {@code -1} for its own number of parameters
     */
    @Override
    public int getDimensionOfScope() {
        int all = this.function.length - (this.freeParameters ? 1 : 0);
        for (int i = 0; i < this.function.length; ++i) {
            int current = this.function[i].getNumberOfParameters();
            if (current == -1) {
                // NOTE(review): -1 presumably means "not yet known" - confirm against the model API.
                return -1;
            }
            all += current;
        }
        return all;
    }

    /**
     * Returns a fresh, unconfigured instance; no state of this object is copied.
     *
     * @return a new {@code CompositeLogPrior}
     */
    @Override
    public CompositeLogPrior getNewInstance() throws CloneNotSupportedException {
        return new CompositeLogPrior();
    }

    /**
     * Returns an empty buffer; this prior does not serialise any state.
     *
     * @return an empty {@link StringBuffer}
     */
    @Override
    public StringBuffer toXML() {
        return new StringBuffer(1);
    }

    /**
     * Returns the human-readable name of this prior.
     *
     * @return the constant string {@code "Composite log prior"}
     */
    @Override
    public String getInstanceName() {
        return "Composite log prior";
    }
}

