/*
 * Decompiled with CFR 0.152.
 */
package de.jstacs.scoringFunctions.homogeneous;

import de.jstacs.NonParsableException;
import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.Sample;
import de.jstacs.data.Sequence;
import de.jstacs.io.XMLParser;
import de.jstacs.models.discrete.inhomogeneous.MEMConstraint;
import de.jstacs.scoringFunctions.homogeneous.HomogeneousScoringFunction;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.IntList;
import de.jstacs.utils.random.DirichletMRG;
import de.jstacs.utils.random.FastDirichletMRGParams;
import de.jtem.numericalMethods.calculus.specialFunctions.Gamma;
import java.util.Arrays;

public class HMM0ScoringFunction
extends HomogeneousScoringFunction {
    private double ess;
    private double norm;
    private double sumOfHyperParams;
    private double logGammaSum;
    private int[] counter;
    private boolean freeParams;
    private boolean plugIn;
    private boolean optimize;
    private MEMConstraint params;
    private int anz;

    public HMM0ScoringFunction(AlphabetContainer alphabets, int length, double ess, boolean plugIn, boolean optimize) {
        super(alphabets, length);
        if (ess < 0.0) {
            throw new IllegalArgumentException("The ess has to be non-negative.");
        }
        this.ess = ess;
        this.sumOfHyperParams = ess * (double)length;
        this.params = new MEMConstraint(new int[]{0}, new int[]{(int)alphabets.getAlphabetLengthAt(0)});
        this.plugIn = plugIn;
        this.optimize = optimize;
        this.setFreeParams(false);
        this.norm = 1.0;
        double d = -Math.log(alphabets.getAlphabetLengthAt(0));
        for (int i = 0; i < this.counter.length; ++i) {
            this.params.setLambda(i, d);
        }
        this.computeConstantsOfLogPrior();
    }

    public HMM0ScoringFunction(StringBuffer xml) throws NonParsableException {
        super(xml);
    }

    public HMM0ScoringFunction clone() throws CloneNotSupportedException {
        HMM0ScoringFunction clone = (HMM0ScoringFunction)super.clone();
        clone.params = this.params.clone();
        clone.counter = (int[])this.counter.clone();
        return clone;
    }

    public String getInstanceName() {
        return "hMM(0)";
    }

    public double getLogScore(Sequence seq, int start, int length) {
        double erg = 0.0;
        for (int l = 0; l < length; ++l) {
            erg += this.params.getLambda(this.params.satisfiesSpecificConstraint(seq, start + l));
        }
        return erg;
    }

    public double getLogScoreAndPartialDerivation(Sequence seq, int start, int length, IntList indices, DoubleList dList) {
        int l;
        Arrays.fill(this.counter, 0);
        for (l = 0; l < length; ++l) {
            int n = this.params.satisfiesSpecificConstraint(seq, start + l);
            this.counter[n] = this.counter[n] + 1;
        }
        double erg = 0.0;
        for (l = 0; l < this.counter.length; ++l) {
            if (this.counter[l] <= 0) continue;
            erg += (double)this.counter[l] * this.params.getLambda(l);
            if (l >= this.anz) continue;
            indices.add(l);
            dList.add(this.counter[l]);
        }
        return erg;
    }

    public int getNumberOfParameters() {
        return this.anz;
    }

    public void setParameters(double[] params, int start) {
        if (this.optimize) {
            this.norm = 0.0;
            for (int i = 0; i < this.anz; ++i) {
                this.params.setLambda(i, params[start + i]);
                this.norm += this.params.getExpLambda(i);
            }
            if (this.anz < this.counter.length) {
                this.norm += this.params.getExpLambda(this.anz);
            }
        }
    }

    public StringBuffer toXML() {
        StringBuffer b = new StringBuffer(1000);
        XMLParser.appendIntWithTags(b, this.length, "length");
        XMLParser.appendStorableWithTags(b, this.alphabets, "alphabets");
        XMLParser.appendDoubleWithTags(b, this.ess, "ess");
        XMLParser.appendDoubleWithTags(b, this.sumOfHyperParams, "sumOfHyperParams");
        XMLParser.appendStorableWithTags(b, this.params, "params");
        XMLParser.appendBooleanWithTags(b, this.freeParams, "freeParams");
        XMLParser.appendBooleanWithTags(b, this.plugIn, "plugIn");
        XMLParser.appendBooleanWithTags(b, this.optimize, "optimize");
        XMLParser.addTags(b, this.getClass().getSimpleName());
        return b;
    }

    public double[] getCurrentParameterValues() {
        double[] erg = new double[this.anz];
        for (int i = 0; i < this.anz; ++i) {
            erg[i] = this.params.getLambda(i);
        }
        return erg;
    }

    public double[][][] getAllConditionalStationaryDistributions() {
        double[][][] erg = new double[1][1][this.params.getNumberOfSpecificConstraints()];
        double norm = this.getNormalizationConstant(1);
        for (int i = 0; i < erg.length; ++i) {
            erg[0][0][i] = this.params.getExpLambda(i) / norm;
        }
        return erg;
    }

    public void initializeFunction(int index, boolean freeParams, Sample[] data, double[][] weights) {
        this.params.reset();
        if (this.plugIn) {
            if (data != null && data[index] != null) {
                for (int i = 0; i < data[index].getNumberOfElements(); ++i) {
                    Sequence seq = data[index].getElementAt(i);
                    int l = seq.getLength();
                    for (int k = 0; k < l; ++k) {
                        this.params.add(seq.discreteVal(k), weights[index][i]);
                    }
                }
            }
            this.params.estimate(this.sumOfHyperParams);
            for (int i = 0; i < this.counter.length; ++i) {
                this.params.setExpLambda(i, this.params.getFreq(i));
            }
        } else {
            double d = -Math.log(this.alphabets.getAlphabetLengthAt(0));
            for (int i = 0; i < this.counter.length; ++i) {
                this.params.setLambda(i, d);
            }
        }
        this.norm = 1.0;
        this.setFreeParams(freeParams);
    }

    public void initializeFunctionRandomly(boolean freeParams) {
        int n = this.counter.length;
        double[] p = DirichletMRG.DEFAULT_INSTANCE.generate(n, new FastDirichletMRGParams(this.sumOfHyperParams == 0.0 ? 1.0 : this.sumOfHyperParams / (double)n));
        for (int i = 0; i < n; ++i) {
            this.params.setExpLambda(i, p[i]);
        }
        this.norm = 1.0;
        this.setFreeParams(freeParams);
    }

    protected void fromXML(StringBuffer xml) throws NonParsableException {
        StringBuffer b = XMLParser.extractForTag(xml, this.getClass().getSimpleName());
        this.length = XMLParser.extractIntForTag(b, "length");
        this.alphabets = (AlphabetContainer)XMLParser.extractStorableForTag(b, "alphabets");
        this.ess = XMLParser.extractDoubleForTag(b, "ess");
        this.sumOfHyperParams = XMLParser.extractDoubleForTag(b, "sumOfHyperParams");
        this.params = (MEMConstraint)XMLParser.extractStorableForTag(b, "params");
        this.plugIn = XMLParser.extractBooleanForTag(b, "plugIn");
        this.optimize = XMLParser.extractBooleanForTag(b, "optimize");
        this.setFreeParams(XMLParser.extractBooleanForTag(b, "freeParams"));
        int i = 0;
        while (i < this.params.getNumberOfSpecificConstraints()) {
            this.norm += this.params.getExpLambda(i++);
        }
        this.computeConstantsOfLogPrior();
    }

    private void setFreeParams(boolean freeParams) {
        this.freeParams = freeParams;
        this.counter = new int[this.params.getNumberOfSpecificConstraints()];
        this.anz = this.optimize ? this.counter.length - (freeParams ? 1 : 0) : 0;
        if (freeParams) {
            for (int i = 0; i < this.params.getNumberOfSpecificConstraints(); ++i) {
                this.params.setLambda(i, this.params.getLambda(i) - this.params.getLambda(this.params.getNumberOfSpecificConstraints() - 1));
            }
        }
    }

    public int getSizeOfEventSpaceForRandomVariablesOfParameter(int index) {
        if (index < this.anz) {
            return this.params.getNumberOfSpecificConstraints();
        }
        throw new IndexOutOfBoundsException();
    }

    public double getNormalizationConstant(int length) {
        if (length == 0) {
            throw new RuntimeException("The normalization constant can not be computed for length 0.");
        }
        return Math.pow(this.norm, length);
    }

    public double getPartialNormalizationConstant(int parameterIndex, int length) throws Exception {
        if (parameterIndex < this.anz) {
            double erg = (double)length * Math.pow(this.norm, length - 1) * this.params.getExpLambda(parameterIndex);
            return erg;
        }
        throw new IndexOutOfBoundsException();
    }

    public double getEss() {
        return this.ess;
    }

    public String toString() {
        StringBuffer info = new StringBuffer(100);
        info.append(this.alphabets.getSymbol(0, 0.0) + ": " + this.params.getExpLambda(0) / this.norm);
        for (int i = 1; i < this.params.getNumberOfSpecificConstraints(); ++i) {
            info.append("\t" + this.alphabets.getSymbol(0, i) + ": " + this.params.getExpLambda(i) / this.norm);
        }
        return info.toString();
    }

    public double getLogPriorTerm() {
        if (this.optimize) {
            double val = 0.0;
            int n = this.params.getNumberOfSpecificConstraints();
            int i = 0;
            while (i < n) {
                val += this.params.getLambda(i++);
            }
            return val * this.sumOfHyperParams / (double)n + this.logGammaSum;
        }
        return 0.0;
    }

    private void computeConstantsOfLogPrior() {
        int anz = this.params.getNumberOfSpecificConstraints();
        this.logGammaSum = Gamma.logOfGamma((double)this.sumOfHyperParams) - (double)anz * Gamma.logOfGamma((double)(this.sumOfHyperParams / (double)anz));
    }

    public void addGradientOfLogPriorTerm(double[] grad, int start) {
        double d = this.sumOfHyperParams / (double)this.params.getNumberOfSpecificConstraints();
        for (int i = 0; i < this.anz; ++i) {
            int n = start + i;
            grad[n] = grad[n] + d;
        }
    }

    public boolean isInitialized() {
        return true;
    }

    public int getMaximalMarkovOrder() {
        return 0;
    }

    public void setStatisticForHyperparameters(int[] length, double[] weight) throws Exception {
        if (weight.length != length.length) {
            throw new IllegalArgumentException("The length of both arrays (length, weight) have to be identical.");
        }
        this.sumOfHyperParams = 0.0;
        for (int i = 0; i < length.length; ++i) {
            if (weight[i] < 0.0 || length[i] < 0) {
                throw new IllegalArgumentException("check length and weight for entry " + i);
            }
            this.sumOfHyperParams += (double)length[i] * weight[i];
        }
        this.computeConstantsOfLogPrior();
    }
}

