/*
 * Decompiled with CFR 0.152.
 */
package de.jstacs.sequenceScores.statisticalModels.differentiable.homogeneous;

import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.DataSet;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.io.NonParsableException;
import de.jstacs.io.XMLParser;
import de.jstacs.sequenceScores.statisticalModels.differentiable.homogeneous.HomogeneousDiffSM;
import de.jstacs.sequenceScores.statisticalModels.trainable.discrete.inhomogeneous.MEMConstraint;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.IntList;
import de.jstacs.utils.random.DirichletMRG;
import de.jstacs.utils.random.FastDirichletMRGParams;
import de.jtem.numericalMethods.calculus.specialFunctions.Gamma;
import java.util.Arrays;

/**
 * Homogeneous Markov model of order 0 ({@code hMM(0)}) as a differentiable statistical model.
 * <p>
 * Every symbol of the alphabet at position 0 has one log-score parameter ("lambda"). The
 * log-score of a subsequence is the sum of the lambdas of its symbols. The corresponding
 * normalization is reported separately via {@link #getLogNormalizationConstant(int)}, based on
 * {@code norm}, the sum of {@code exp(lambda)} over all symbols.
 */
public class HomogeneousMM0DiffSM extends HomogeneousDiffSM {
    /** The equivalent sample size given at construction. */
    private double ess;
    /** Sum of exp(lambda_i) over all symbols (NOT in log space); kept in sync with the lambdas. */
    private double norm;
    /** Sum of the (symmetric) hyperparameters of the prior. */
    private double sumOfHyperParams;
    /** Cached constant part of the log-prior, see {@link #computeConstantsOfLogPrior()}. */
    private double logGammaSum;
    /** Scratch buffer for per-symbol counts; one slot per symbol. */
    private int[] counter;
    /** Whether the last lambda is fixed to 0 and excluded from the optimized parameters. */
    private boolean freeParams;
    /** Whether {@link #initializeFunction} uses a plug-in estimate from the data. */
    private boolean plugIn;
    /** Whether the lambdas are exposed as parameters at all. */
    private boolean optimize;
    /** Holds the lambdas and the counts used for the plug-in estimate. */
    private MEMConstraint params;
    /** Number of parameters exposed to the optimizer. */
    private int anz;

    /**
     * Creates a new model initialized with the uniform distribution over the symbols.
     *
     * @param alphabets the alphabets; only the alphabet at position 0 is used
     * @param length the sequence length, used to derive the sum of hyperparameters
     * @param ess the equivalent sample size; must be non-negative
     * @param plugIn whether {@link #initializeFunction} estimates the lambdas from data
     * @param optimize whether the lambdas are exposed as optimizable parameters
     * @throws IllegalArgumentException if {@code ess} is negative
     */
    public HomogeneousMM0DiffSM(AlphabetContainer alphabets, int length, double ess, boolean plugIn, boolean optimize) {
        super(alphabets, length);
        if (ess < 0.0) {
            throw new IllegalArgumentException("The ess has to be non-negative.");
        }
        this.ess = ess;
        this.sumOfHyperParams = ess * length;
        this.params = new MEMConstraint(new int[1], new int[]{(int) alphabets.getAlphabetLengthAt(0)});
        this.plugIn = plugIn;
        this.optimize = optimize;
        this.setUniformLambdas(alphabets.getAlphabetLengthAt(0));
        // also allocates the counter, sets anz and recomputes norm
        this.setFreeParams(false);
        this.computeConstantsOfLogPrior();
    }

    /**
     * Restores a model from its XML representation.
     *
     * @param xml the XML representation as produced by {@link #toXML()}
     * @throws NonParsableException if the XML cannot be parsed
     */
    public HomogeneousMM0DiffSM(StringBuffer xml) throws NonParsableException {
        super(xml);
    }

    @Override
    public HomogeneousMM0DiffSM clone() throws CloneNotSupportedException {
        HomogeneousMM0DiffSM clone = (HomogeneousMM0DiffSM) super.clone();
        // deep-copy the mutable state so the clone does not share lambdas or the scratch buffer
        clone.params = this.params.clone();
        clone.counter = this.counter.clone();
        return clone;
    }

    @Override
    public String getInstanceName() {
        return "hMM(0)";
    }

    /**
     * Returns the sum of the lambdas of the symbols at positions {@code start..end} (inclusive).
     */
    @Override
    public double getLogScoreFor(Sequence seq, int start, int end) {
        double logScore = 0.0;
        for (int pos = start; pos <= end; pos++) {
            logScore += this.params.getLambda(this.params.satisfiesSpecificConstraint(seq, pos));
        }
        return logScore;
    }

    /**
     * Returns the log-score of {@code seq[start..end]} (inclusive). For every optimized parameter
     * whose symbol occurs, this appends its index to {@code indices} and its occurrence count to
     * {@code dList}. The count is the partial derivative of the log-score with respect to that
     * lambda.
     */
    @Override
    public double getLogScoreAndPartialDerivation(Sequence seq, int start, int end, IntList indices, DoubleList dList) {
        Arrays.fill(this.counter, 0);
        for (int pos = start; pos <= end; pos++) {
            this.counter[this.params.satisfiesSpecificConstraint(seq, pos)]++;
        }
        double logScore = 0.0;
        for (int symbol = 0; symbol < this.counter.length; symbol++) {
            int count = this.counter[symbol];
            if (count > 0) {
                logScore += count * this.params.getLambda(symbol);
                // only lambdas with index < anz are parameters; a fixed last lambda has no derivative
                if (symbol < this.anz) {
                    indices.add(symbol);
                    dList.add(count);
                }
            }
        }
        return logScore;
    }

    @Override
    public int getNumberOfParameters() {
        return this.anz;
    }

    /**
     * Copies the {@link #getNumberOfParameters()} lambdas from {@code params[start..]} and updates
     * the normalization. This does nothing if the model is not optimized.
     */
    @Override
    public void setParameters(double[] params, int start) {
        if (!this.optimize) {
            return;
        }
        for (int i = 0; i < this.anz; i++) {
            this.params.setLambda(i, params[start + i]);
        }
        this.recomputeNorm();
    }

    @Override
    public StringBuffer toXML() {
        StringBuffer b = new StringBuffer(1000);
        XMLParser.appendObjectWithTags(b, this.length, "length");
        XMLParser.appendObjectWithTags(b, this.alphabets, "alphabets");
        XMLParser.appendObjectWithTags(b, this.ess, "ess");
        XMLParser.appendObjectWithTags(b, this.sumOfHyperParams, "sumOfHyperParams");
        XMLParser.appendObjectWithTags(b, this.params, "params");
        XMLParser.appendObjectWithTags(b, this.freeParams, "freeParams");
        XMLParser.appendObjectWithTags(b, this.plugIn, "plugIn");
        XMLParser.appendObjectWithTags(b, this.optimize, "optimize");
        XMLParser.addTags(b, this.getClass().getSimpleName());
        return b;
    }

    @Override
    public double[] getCurrentParameterValues() {
        double[] values = new double[this.anz];
        for (int i = 0; i < this.anz; i++) {
            values[i] = this.params.getLambda(i);
        }
        return values;
    }

    /**
     * Initializes the lambdas. In plug-in mode they come from the (weighted) symbol counts of
     * {@code data[index]} via {@code MEMConstraint.estimate}. Otherwise they come from the uniform
     * distribution.
     * <p>
     * If {@code weights} or {@code weights[index]} is {@code null}, every sequence gets weight 1.
     */
    @Override
    public void initializeFunction(int index, boolean freeParams, DataSet[] data, double[][] weights) {
        this.params.reset();
        if (this.plugIn) {
            if (data != null && data[index] != null) {
                DataSet current = data[index];
                double[] seqWeights = weights == null ? null : weights[index];
                for (int i = 0; i < current.getNumberOfElements(); i++) {
                    Sequence seq = current.getElementAt(i);
                    double weight = seqWeights == null ? 1.0 : seqWeights[i];
                    for (int k = 0; k < seq.getLength(); k++) {
                        this.params.add(seq.discreteVal(k), weight);
                    }
                }
            }
            this.params.estimate(this.sumOfHyperParams);
            for (int i = 0; i < this.counter.length; i++) {
                this.params.setExpLambda(i, this.params.getFreq(i));
            }
        } else {
            this.setUniformLambdas(this.alphabets.getAlphabetLengthAt(0));
        }
        // recomputes norm after the (optional) shift of the lambdas
        this.setFreeParams(freeParams);
    }

    /**
     * Draws the symbol probabilities from a symmetric Dirichlet distribution. The distribution
     * uses hyperparameter {@code sumOfHyperParams / n}, or 1 if the sum is 0.
     */
    @Override
    public void initializeFunctionRandomly(boolean freeParams) {
        int n = this.counter.length;
        double[] p = DirichletMRG.DEFAULT_INSTANCE.generate(n, new FastDirichletMRGParams(this.sumOfHyperParams == 0.0 ? 1.0 : this.sumOfHyperParams / n));
        for (int i = 0; i < n; i++) {
            this.params.setExpLambda(i, p[i]);
        }
        this.setFreeParams(freeParams);
    }

    @Override
    protected void fromXML(StringBuffer xml) throws NonParsableException {
        StringBuffer b = XMLParser.extractForTag(xml, this.getClass().getSimpleName());
        this.length = XMLParser.extractObjectForTags(b, "length", Integer.TYPE);
        this.alphabets = (AlphabetContainer) XMLParser.extractObjectForTags(b, "alphabets");
        this.ess = XMLParser.extractObjectForTags(b, "ess", Double.TYPE);
        this.sumOfHyperParams = XMLParser.extractObjectForTags(b, "sumOfHyperParams", Double.TYPE);
        this.params = XMLParser.extractObjectForTags(b, "params", MEMConstraint.class);
        this.plugIn = XMLParser.extractObjectForTags(b, "plugIn", Boolean.TYPE);
        this.optimize = XMLParser.extractObjectForTags(b, "optimize", Boolean.TYPE);
        // setFreeParams recomputes norm from scratch, so no stale value can leak in
        this.setFreeParams(XMLParser.extractObjectForTags(b, "freeParams", Boolean.TYPE));
        this.computeConstantsOfLogPrior();
    }

    /**
     * Sets the free-parameter mode, reallocates the counter, and determines the number of
     * optimized parameters. If {@code freeParams} is set, it shifts all lambdas so that the last
     * lambda becomes 0. Finally it recomputes {@code norm} so it matches the (possibly shifted)
     * lambdas.
     */
    private void setFreeParams(boolean freeParams) {
        this.freeParams = freeParams;
        int n = this.params.getNumberOfSpecificConstraints();
        this.counter = new int[n];
        this.anz = this.optimize ? n - (freeParams ? 1 : 0) : 0;
        if (freeParams) {
            double last = this.params.getLambda(n - 1);
            for (int i = 0; i < n; i++) {
                this.params.setLambda(i, this.params.getLambda(i) - last);
            }
        }
        this.recomputeNorm();
    }

    /** Sets {@code norm} to the sum of exp(lambda_i) over all symbols. */
    private void recomputeNorm() {
        double sum = 0.0;
        for (int i = 0; i < this.params.getNumberOfSpecificConstraints(); i++) {
            sum += this.params.getExpLambda(i);
        }
        this.norm = sum;
    }

    /** Sets every lambda to {@code -log(alphabetLength)}, i.e. the uniform distribution. */
    private void setUniformLambdas(double alphabetLength) {
        double uniform = -Math.log(alphabetLength);
        for (int i = 0; i < this.params.getNumberOfSpecificConstraints(); i++) {
            this.params.setLambda(i, uniform);
        }
    }

    @Override
    public int getSizeOfEventSpaceForRandomVariablesOfParameter(int index) {
        if (index < this.anz) {
            return this.params.getNumberOfSpecificConstraints();
        }
        throw new IndexOutOfBoundsException();
    }

    /**
     * Returns {@code length * log(norm)}, the log of the normalization constant
     * {@code norm^length}.
     *
     * @throws RuntimeException if {@code length} is 0
     */
    @Override
    public double getLogNormalizationConstant(int length) {
        if (length == 0) {
            throw new RuntimeException("The normalization constant can not be computed for length 0.");
        }
        // norm is a plain sum of exp(lambda), so it must be logged before scaling by the length
        return Math.log(this.norm) * length;
    }

    /**
     * Returns the log of the partial derivative of {@code norm^length} with respect to the lambda
     * at {@code parameterIndex}. That is
     * {@code log(length) + (length - 1) * log(norm) + lambda_parameterIndex}.
     *
     * @throws IndexOutOfBoundsException if {@code parameterIndex} is not a parameter index
     */
    @Override
    public double getLogPartialNormalizationConstant(int parameterIndex, int length) throws Exception {
        if (parameterIndex < this.anz) {
            return Math.log(length) + Math.log(this.norm) * (length - 1) + this.params.getLambda(parameterIndex);
        }
        throw new IndexOutOfBoundsException();
    }

    @Override
    public double getESS() {
        return this.ess;
    }

    /** Lists {@code symbol: probability} for every symbol, separated by tabs. */
    @Override
    public String toString() {
        StringBuffer info = new StringBuffer(100);
        for (int i = 0; i < this.params.getNumberOfSpecificConstraints(); i++) {
            if (i > 0) {
                info.append('\t');
            }
            info.append(this.alphabets.getSymbol(0, i)).append(": ").append(this.params.getExpLambda(i) / this.norm);
        }
        return info.toString();
    }

    /**
     * Returns {@code sum_i(lambda_i) * sumOfHyperParams / n + logGammaSum} if the model is
     * optimized, otherwise 0.
     */
    @Override
    public double getLogPriorTerm() {
        if (!this.optimize) {
            return 0.0;
        }
        int n = this.params.getNumberOfSpecificConstraints();
        double lambdaSum = 0.0;
        for (int i = 0; i < n; i++) {
            lambdaSum += this.params.getLambda(i);
        }
        return lambdaSum * this.sumOfHyperParams / n + this.logGammaSum;
    }

    /**
     * Caches {@code logGamma(S) - n * logGamma(S / n)} for {@code S = sumOfHyperParams}.
     * <p>
     * When {@code S} is 0 (allowed because ess may be 0), logGamma would be evaluated at 0.
     * In that case the constant is set to 0, consistent with the vanishing prior term
     * {@code lambdaSum * 0}.
     */
    private void computeConstantsOfLogPrior() {
        int n = this.params.getNumberOfSpecificConstraints();
        if (this.sumOfHyperParams > 0.0) {
            this.logGammaSum = Gamma.logOfGamma(this.sumOfHyperParams) - n * Gamma.logOfGamma(this.sumOfHyperParams / n);
        } else {
            this.logGammaSum = 0.0;
        }
    }

    /** Adds {@code sumOfHyperParams / n} to the gradient entry of every optimized parameter. */
    @Override
    public void addGradientOfLogPriorTerm(double[] grad, int start) {
        double d = this.sumOfHyperParams / this.params.getNumberOfSpecificConstraints();
        for (int i = 0; i < this.anz; i++) {
            grad[start + i] += d;
        }
    }

    @Override
    public boolean isInitialized() {
        return true;
    }

    @Override
    public byte getMaximalMarkovOrder() {
        return 0;
    }

    /**
     * Sets {@code sumOfHyperParams} to {@code sum_i(length[i] * weight[i])} and refreshes the
     * cached prior constant.
     *
     * @throws IllegalArgumentException if the arrays differ in length or contain negative entries
     */
    @Override
    public void setStatisticForHyperparameters(int[] length, double[] weight) throws Exception {
        if (weight.length != length.length) {
            throw new IllegalArgumentException("The length of both arrays (length, weight) have to be identical.");
        }
        this.sumOfHyperParams = 0.0;
        for (int i = 0; i < length.length; i++) {
            if (weight[i] < 0.0 || length[i] < 0) {
                throw new IllegalArgumentException("check length and weight for entry " + i);
            }
            this.sumOfHyperParams += length[i] * weight[i];
        }
        this.computeConstantsOfLogPrior();
    }

    @Override
    public void initializeUniformly(boolean freeParams) {
        double p = 1.0 / this.counter.length;
        for (int i = 0; i < this.counter.length; i++) {
            this.params.setExpLambda(i, p);
        }
        this.setFreeParams(freeParams);
    }

    /**
     * Returns a single sampling group that contains the indices of all optimized parameters
     * ({@code parameterOffset .. parameterOffset + getNumberOfParameters() - 1}).
     */
    @Override
    public int[][] getSamplingGroups(int parameterOffset) {
        // only anz parameters exist; using the symbol count would reach into foreign parameters
        int[] group = new int[this.anz];
        for (int i = 0; i < group.length; i++) {
            group[i] = parameterOffset + i;
        }
        return new int[][]{group};
    }
}

