/*
 * Decompiled with CFR 0.152.
 */
package de.jstacs.scoringFunctions.homogeneous;

import de.jstacs.NonParsableException;
import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.Sample;
import de.jstacs.data.Sequence;
import de.jstacs.data.alphabets.DiscreteAlphabet;
import de.jstacs.data.sequences.ByteSequence;
import de.jstacs.io.XMLParser;
import de.jstacs.models.utils.StationaryDistribution;
import de.jstacs.scoringFunctions.homogeneous.HomogeneousScoringFunction;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.IntList;
import de.jstacs.utils.random.DirichletMRG;
import de.jstacs.utils.random.FastDirichletMRGParams;
import de.jtem.numericalMethods.calculus.specialFunctions.Gamma;
import java.util.Arrays;
import java.util.Random;

public class HMMScoringFunction
extends HomogeneousScoringFunction {
    private boolean freeParams;
    private boolean plugIn;
    private boolean optimize;
    private int order;
    private int starts;
    private int[] powers;
    private double classEss;
    private double logGammaSum;
    private double[][] params;
    private double[][] probs;
    private double[][] logNorm;
    private double[] sumOfHyperParams;
    private int[] counter;
    private int[] distCounter;
    private int[] offset;

    public static double[] getSumOfHyperParameters(int order, int length, double ess) {
        double[] sumOfHyperParams = new double[order + 1];
        Arrays.fill(sumOfHyperParams, ess);
        sumOfHyperParams[order] = (double)(length - order) * ess;
        return sumOfHyperParams;
    }

    public HMMScoringFunction(AlphabetContainer alphabets, int order, double classEss, int length) {
        this(alphabets, order, classEss, HMMScoringFunction.getSumOfHyperParameters(order, length, classEss), true, true, 1);
    }

    public HMMScoringFunction(AlphabetContainer alphabets, int order, double classEss, double[] sumOfHyperParams, boolean plugIn, boolean optimize, int starts) {
        super(alphabets);
        if (order < 0) {
            throw new IllegalArgumentException("The order has to be non-negative.");
        }
        this.order = order;
        this.createArrays();
        if (classEss < 0.0) {
            throw new IllegalArgumentException("The ess for the class has to be non-negative.");
        }
        this.classEss = classEss;
        if (sumOfHyperParams == null) {
            sumOfHyperParams = new double[order + 1];
        } else {
            if (sumOfHyperParams.length != order + 1) {
                throw new IllegalArgumentException("Wrong dimension of the ess array.");
            }
            this.sumOfHyperParams = new double[order + 1];
            for (int i = 0; i <= order; ++i) {
                if (sumOfHyperParams[i] < 0.0) {
                    throw new IllegalArgumentException("The ess has to be non-negative. Violated at position " + i + ".");
                }
                if (i > 0 && i < order && sumOfHyperParams[i] > sumOfHyperParams[i - 1]) {
                    throw new IllegalArgumentException("The ess for start probabilities of order " + i + " is inconsistent with the ess for the probabilities of the previous order.");
                }
                this.sumOfHyperParams[i] = sumOfHyperParams[i];
            }
        }
        this.params = new double[order + 1][];
        double uniform = 1.0 / (double)this.powers[1];
        double logUniform = Math.log(uniform);
        for (int i = 0; i <= order; ++i) {
            this.params[i] = new double[this.powers[i + 1]];
            this.probs[i] = new double[this.powers[i + 1]];
            this.logNorm[i] = new double[this.powers[i]];
            Arrays.fill(this.params[i], logUniform);
            Arrays.fill(this.probs[i], uniform);
        }
        this.plugIn = plugIn;
        this.optimize = optimize;
        if (starts <= 0) {
            throw new IllegalArgumentException("The number of starts has to be positive.");
        }
        this.starts = starts;
        this.setFreeParams(false);
        this.computeConstantsOfLogPrior();
    }

    public HMMScoringFunction(StringBuffer xml) throws NonParsableException {
        super(xml);
    }

    private void createArrays() {
        this.powers = new int[this.order + 2];
        this.powers[0] = 1;
        this.powers[1] = (int)this.alphabets.getAlphabetLengthAt(0);
        for (int i = 2; i < this.powers.length; ++i) {
            this.powers[i] = this.powers[i - 1] * this.powers[1];
        }
        this.probs = new double[this.order + 1][];
        this.logNorm = new double[this.order + 1][];
        this.counter = new int[this.powers[this.order + 1]];
        this.distCounter = new int[this.powers[this.order]];
        this.offset = new int[this.order + 2];
    }

    @Override
    public HMMScoringFunction clone() throws CloneNotSupportedException {
        HMMScoringFunction clone = (HMMScoringFunction)super.clone();
        clone.params = new double[this.params.length][];
        clone.probs = new double[this.probs.length][];
        clone.logNorm = new double[this.logNorm.length][];
        for (int i = 0; i <= this.order; ++i) {
            clone.params[i] = (double[])this.params[i].clone();
            clone.probs[i] = (double[])this.probs[i].clone();
            clone.logNorm[i] = (double[])this.logNorm[i].clone();
        }
        clone.sumOfHyperParams = (double[])this.sumOfHyperParams.clone();
        clone.counter = (int[])this.counter.clone();
        clone.distCounter = (int[])this.distCounter.clone();
        clone.offset = (int[])this.offset.clone();
        return clone;
    }

    @Override
    public String getInstanceName() {
        return "hMM(" + this.order + ")";
    }

    @Override
    public double getLogScore(Sequence seq, int start, int length) {
        int indexOld;
        int l;
        double erg = 0.0;
        int indexNew = 0;
        int o = Math.min(this.order, length);
        for (l = 0; l < o; ++l) {
            indexOld = indexNew;
            indexNew = indexOld * this.powers[1] + seq.discreteVal(start++);
            erg += this.params[l][indexNew] - this.logNorm[l][indexOld];
        }
        while (l < length) {
            indexOld = indexNew % this.powers[this.order];
            indexNew = indexOld * this.powers[1] + seq.discreteVal(start++);
            erg += this.params[this.order][indexNew] - this.logNorm[this.order][indexOld];
            ++l;
        }
        return erg;
    }

    @Override
    public double getLogScoreAndPartialDerivation(Sequence seq, int start, int length, IntList indices, DoubleList dList) {
        if (this.optimize) {
            int index;
            int h;
            int indexOld;
            int l;
            Arrays.fill(this.counter, 0);
            Arrays.fill(this.distCounter, 0);
            double erg = 0.0;
            int stop = this.powers[1] - (this.freeParams ? 1 : 0);
            int indexNew = 0;
            int o = Math.min(this.order, length);
            for (l = 0; l < o; ++l) {
                indexOld = indexNew;
                int z = indexOld * this.powers[1];
                indexNew = z + seq.discreteVal(start++);
                erg += this.params[l][indexNew] - this.logNorm[l][indexOld];
                h = z - (this.freeParams ? indexOld : 0);
                for (index = 0; index < stop; ++index) {
                    indices.add(this.offset[l] + h + index);
                    if (z + index == indexNew) {
                        dList.add(1.0 - this.probs[l][z + index]);
                        continue;
                    }
                    dList.add(-this.probs[l][z + index]);
                }
            }
            while (l < length) {
                indexOld = indexNew % this.powers[this.order];
                indexNew = indexOld * this.powers[1] + seq.discreteVal(start++);
                int n = indexOld;
                this.distCounter[n] = this.distCounter[n] + 1;
                int n2 = indexNew;
                this.counter[n2] = this.counter[n2] + 1;
                ++l;
            }
            for (l = 0; l < this.distCounter.length; ++l) {
                if (this.distCounter[l] <= 0) continue;
                h = l * (this.powers[1] - (this.freeParams ? 1 : 0));
                o = l * this.powers[1];
                index = 0;
                while (index < stop) {
                    indices.add(this.offset[this.order] + h);
                    dList.add((double)this.counter[o] - (double)this.distCounter[l] * this.probs[this.order][o]);
                    erg += (double)this.counter[o] * this.params[this.order][o];
                    ++index;
                    ++h;
                    ++o;
                }
                if (stop < this.powers[1]) {
                    erg += (double)this.counter[o] * this.params[this.order][o];
                }
                erg -= (double)this.distCounter[l] * this.logNorm[this.order][l];
            }
            return erg;
        }
        return this.getLogScore(seq, start, length);
    }

    @Override
    public int getNumberOfParameters() {
        if (this.optimize) {
            return this.offset[this.order + 1];
        }
        return 0;
    }

    @Override
    public void setParameters(double[] params, int start) {
        if (this.optimize) {
            int stop = this.powers[1] - (this.freeParams ? 1 : 0);
            for (int o = 0; o <= this.order; ++o) {
                int index = 0;
                for (int n = 0; n < this.logNorm[o].length; ++n) {
                    this.logNorm[o][n] = 0.0;
                    int j = 0;
                    while (j < stop) {
                        this.params[o][index + j] = params[start];
                        this.probs[o][index + j] = Math.exp(this.params[o][index + j]);
                        double[] dArray = this.logNorm[o];
                        int n2 = n;
                        dArray[n2] = dArray[n2] + this.probs[o][index + j];
                        ++j;
                        ++start;
                    }
                    if (j < this.powers[1]) {
                        this.probs[o][index + j] = Math.exp(this.params[o][index + j]);
                        double[] dArray = this.logNorm[o];
                        int n3 = n;
                        dArray[n3] = dArray[n3] + this.probs[o][index + j];
                    }
                    for (j = 0; j < this.powers[1]; ++j) {
                        double[] dArray = this.probs[o];
                        int n4 = index++;
                        dArray[n4] = dArray[n4] / this.logNorm[o][n];
                    }
                    this.logNorm[o][n] = Math.log(this.logNorm[o][n]);
                }
            }
        }
    }

    @Override
    public StringBuffer toXML() {
        StringBuffer b = new StringBuffer(1000);
        XMLParser.appendObjectWithTags(b, this.length, "length");
        XMLParser.appendObjectWithTags(b, this.alphabets, "alphabets");
        XMLParser.appendObjectWithTags(b, this.order, "order");
        XMLParser.appendObjectWithTags(b, this.classEss, "classEss");
        XMLParser.appendObjectWithTags(b, this.sumOfHyperParams, "sumOfHyperParams");
        XMLParser.appendObjectWithTags(b, this.params, "params");
        XMLParser.appendObjectWithTags(b, this.plugIn, "plugIn");
        XMLParser.appendObjectWithTags(b, this.optimize, "optimize");
        XMLParser.appendObjectWithTags(b, this.starts, "starts");
        XMLParser.appendObjectWithTags(b, this.freeParams, "freeParams");
        XMLParser.addTags(b, this.getClass().getSimpleName());
        return b;
    }

    @Override
    public double[] getCurrentParameterValues() {
        int l = this.optimize ? this.offset[this.order + 1] : 0;
        double[] erg = new double[l];
        if (this.optimize) {
            int stop = this.powers[1] - (this.freeParams ? 1 : 0);
            int i = 0;
            for (int o = 0; o <= this.order; ++o) {
                for (int index = 0; index < this.params[o].length; index += this.powers[1]) {
                    int j = 0;
                    while (j < stop) {
                        erg[i] = this.params[o][index + j];
                        ++j;
                        ++i;
                    }
                }
            }
        }
        return erg;
    }

    @Override
    public void initializeFunction(int index, boolean freeParams, Sample[] data, double[][] weights) {
        if (this.optimize && this.plugIn && data != null && data[index] != null) {
            int indexOld;
            int indexNew;
            int o;
            int len;
            double hyper;
            for (int i = 0; i <= this.order; ++i) {
                hyper = this.sumOfHyperParams[i] / (double)this.probs[i].length;
                Arrays.fill(this.probs[i], hyper);
                Arrays.fill(this.logNorm[i], hyper * (double)this.powers[1]);
            }
            int anz = data[index].getNumberOfElements();
            double w = 1.0;
            boolean externalWeights = weights != null && weights[index] != null;
            for (int i = 0; i < anz; ++i) {
                int l;
                Sequence seq = data[index].getElementAt(i);
                len = seq.getLength();
                o = Math.min(len, this.order);
                indexNew = 0;
                if (externalWeights) {
                    w = weights[index][i];
                }
                for (l = 0; l < o; ++l) {
                    indexOld = indexNew;
                    indexNew = indexOld * this.powers[1] + seq.discreteVal(l);
                    double[] dArray = this.probs[l];
                    int n = indexNew;
                    dArray[n] = dArray[n] + w;
                    double[] dArray2 = this.logNorm[l];
                    int n2 = indexOld;
                    dArray2[n2] = dArray2[n2] + w;
                }
                while (l < len) {
                    indexOld = indexNew % this.powers[this.order];
                    indexNew = indexOld * this.powers[1] + seq.discreteVal(l);
                    double[] dArray = this.probs[this.order];
                    int n = indexNew;
                    dArray[n] = dArray[n] + w;
                    double[] dArray3 = this.logNorm[this.order];
                    int n3 = indexOld;
                    dArray3[n3] = dArray3[n3] + w;
                    ++l;
                }
            }
            for (o = 0; o <= this.order; ++o) {
                indexNew = 0;
                for (indexOld = 0; indexOld < this.logNorm[o].length; ++indexOld) {
                    if (this.logNorm[o][indexOld] > 0.0) {
                        len = 0;
                        while (len < this.powers[1]) {
                            double[] dArray = this.probs[o];
                            int n = indexNew;
                            dArray[n] = dArray[n] / this.logNorm[o][indexOld];
                            this.params[o][indexNew] = Math.log(this.probs[o][indexNew]);
                            ++len;
                            ++indexNew;
                        }
                        if (freeParams) {
                            int last = indexNew - 1;
                            indexNew -= this.powers[1];
                            for (len = 0; len < this.powers[1]; ++len) {
                                double[] dArray = this.params[o];
                                int n = indexNew++;
                                dArray[n] = dArray[n] - this.params[o][last];
                            }
                        }
                    } else {
                        hyper = 1.0 / (double)this.powers[1];
                        len = 0;
                        while (len < this.powers[1]) {
                            this.probs[o][indexNew] = hyper;
                            this.params[o][indexNew] = 0.0;
                            ++len;
                            ++indexNew;
                        }
                    }
                    this.logNorm[o][indexOld] = 0.0;
                }
            }
        } else {
            this.initializeFunctionRandomly(freeParams);
        }
        this.setFreeParams(freeParams);
    }

    @Override
    public void initializeFunctionRandomly(boolean freeParams) {
        if (this.optimize) {
            double[] p = new double[this.powers[1]];
            double offset = 0.0;
            for (int o = 0; o <= this.order; ++o) {
                FastDirichletMRGParams hyper = new FastDirichletMRGParams(this.sumOfHyperParams[o] == 0.0 ? 1.0 : this.sumOfHyperParams[o] / (double)this.probs[o].length);
                int paramCounter = 0;
                for (int normCounter = 0; normCounter < this.logNorm[o].length; ++normCounter) {
                    this.logNorm[o][normCounter] = 0.0;
                    DirichletMRG.DEFAULT_INSTANCE.generate(p, 0, this.powers[1], hyper);
                    if (freeParams) {
                        offset = Math.log(p[this.powers[1] - 1]);
                    }
                    int len = 0;
                    while (len < this.powers[1]) {
                        this.probs[o][paramCounter] = p[len];
                        this.params[o][paramCounter] = Math.log(p[len]) - offset;
                        ++len;
                        ++paramCounter;
                    }
                }
            }
            this.setFreeParams(freeParams);
        }
    }

    @Override
    protected void fromXML(StringBuffer xml) throws NonParsableException {
        StringBuffer b = XMLParser.extractForTag(xml, this.getClass().getSimpleName());
        this.length = XMLParser.extractObjectForTags(b, "length", Integer.TYPE);
        this.alphabets = XMLParser.extractObjectForTags(b, "alphabets", AlphabetContainer.class);
        this.order = XMLParser.extractObjectForTags(b, "order", Integer.TYPE);
        this.createArrays();
        this.classEss = XMLParser.extractObjectForTags(b, "classEss", Double.TYPE);
        this.sumOfHyperParams = XMLParser.extractObjectForTags(b, "sumOfHyperParams", double[].class);
        this.params = XMLParser.extractObjectForTags(b, "params", double[][].class);
        this.plugIn = XMLParser.extractObjectForTags(b, "plugIn", Boolean.TYPE);
        this.optimize = XMLParser.extractObjectForTags(b, "optimize", Boolean.TYPE);
        this.starts = XMLParser.extractObjectForTags(b, "starts", Integer.TYPE);
        this.setFreeParams(XMLParser.extractObjectForTags(b, "freeParams", Boolean.TYPE));
        for (int o = 0; o <= this.order; ++o) {
            this.probs[o] = new double[this.params[o].length];
            this.logNorm[o] = new double[this.powers[o]];
            int index = 0;
            for (int n = 0; n < this.logNorm[o].length; ++n) {
                int j;
                this.logNorm[o][n] = 0.0;
                for (j = 0; j < this.powers[1]; ++j) {
                    this.probs[o][index + j] = Math.exp(this.params[o][index + j]);
                    double[] dArray = this.logNorm[o];
                    int n2 = n;
                    dArray[n2] = dArray[n2] + this.probs[o][index + j];
                }
                for (j = 0; j < this.powers[1]; ++j) {
                    double[] dArray = this.probs[o];
                    int n3 = index++;
                    dArray[n3] = dArray[n3] / this.logNorm[o][n];
                }
                this.logNorm[o][n] = Math.log(this.logNorm[o][n]);
            }
        }
        this.computeConstantsOfLogPrior();
    }

    private void setFreeParams(boolean freeParams) {
        this.freeParams = freeParams;
        if (this.optimize) {
            this.offset[0] = 0;
            for (int i = 0; i <= this.order; ++i) {
                this.offset[i + 1] = this.offset[i] + this.params[i].length - (freeParams ? this.powers[i] : 0);
            }
        } else {
            this.offset[this.order + 1] = 0;
        }
    }

    @Override
    public int getSizeOfEventSpaceForRandomVariablesOfParameter(int index) {
        if (index < this.offset[this.order + 1]) {
            return this.powers[1];
        }
        throw new IndexOutOfBoundsException();
    }

    @Override
    public double getLogNormalizationConstant(int length) {
        return 0.0;
    }

    @Override
    public double getLogPartialNormalizationConstant(int parameterIndex, int length) throws Exception {
        if (parameterIndex < this.offset[this.order + 1]) {
            return Double.NEGATIVE_INFINITY;
        }
        throw new IndexOutOfBoundsException();
    }

    @Override
    public double getEss() {
        return this.classEss;
    }

    public String toString() {
        int i;
        StringBuffer info = new StringBuffer(100);
        DiscreteAlphabet abc = (DiscreteAlphabet)this.alphabets.getAlphabetAt(0);
        int l = (int)abc.length();
        String[] sym = new String[l];
        --l;
        for (i = 0; i <= l; ++i) {
            sym[i] = abc.getSymbolAt(i);
            info.append("\t" + sym[i]);
        }
        info.append("\n");
        int[] context = new int[this.order + 1];
        for (int o = 0; o <= this.order; ++o) {
            info.append("P(X_" + o);
            for (i = 0; i < o; ++i) {
                if (i == 0) {
                    info.append("|");
                } else {
                    info.append(" ");
                }
                info.append("X_" + i);
            }
            info.append(")\n");
            Arrays.fill(context, 0);
            int index = 0;
            while (index < this.probs[o].length) {
                for (i = 0; i < o; ++i) {
                    info.append(sym[context[i]]);
                }
                i = 0;
                while (i <= l) {
                    info.append("\t" + this.probs[o][index]);
                    ++i;
                    ++index;
                }
                info.append("\n");
                for (i = o - 1; i >= 0 && context[i] == l; --i) {
                    context[i] = 0;
                }
                if (i < 0) continue;
                int n = i;
                context[n] = context[n] + 1;
            }
            info.append("\n");
        }
        return info.toString();
    }

    @Override
    public double getLogPriorTerm() {
        if (this.optimize) {
            double val = 0.0;
            for (int o = 0; o <= this.order; ++o) {
                double hyper = this.sumOfHyperParams[o] / (double)this.params[o].length;
                if (!(hyper > 0.0)) continue;
                double hyperSum = (double)this.powers[1] * hyper;
                int index = 0;
                for (int n = 0; n < this.logNorm[o].length; ++n) {
                    val -= hyperSum * this.logNorm[o][n];
                    int j = 0;
                    while (j < this.powers[1]) {
                        val += hyper * this.params[o][index];
                        ++j;
                        ++index;
                    }
                }
            }
            return val + this.logGammaSum;
        }
        return 0.0;
    }

    private void computeConstantsOfLogPrior() {
        this.logGammaSum = 0.0;
        for (int o = 0; o <= this.order; ++o) {
            double hyper = this.sumOfHyperParams[o] / (double)this.params[o].length;
            if (!(hyper > 0.0)) continue;
            this.logGammaSum += (double)this.logNorm[o].length * Gamma.logOfGamma((double)((double)this.powers[1] * hyper)) - (double)this.params[o].length * Gamma.logOfGamma((double)hyper);
        }
    }

    @Override
    public void addGradientOfLogPriorTerm(double[] grad, int start) {
        if (this.optimize) {
            int stop = this.powers[1] - (this.freeParams ? 1 : 0);
            for (int o = 0; o <= this.order; ++o) {
                double hyper = this.sumOfHyperParams[o] / (double)this.params[o].length;
                double hyperSum = (double)this.powers[1] * hyper;
                for (int index = 0; index < this.params[o].length; index += this.powers[1]) {
                    for (int j = 0; j < stop; ++j) {
                        int n = start++;
                        grad[n] = grad[n] + (hyper - hyperSum * this.probs[o][index + j]);
                    }
                }
            }
        }
    }

    @Override
    public boolean isNormalized() {
        return true;
    }

    @Override
    public boolean isInitialized() {
        return true;
    }

    @Override
    public int getMaximalMarkovOrder() {
        return this.order;
    }

    @Override
    public int getNumberOfRecommendedStarts() {
        return this.starts;
    }

    public void setParameterOptimization(boolean optimize) {
        this.optimize = optimize;
    }

    public double[][][] getAllConditionalStationaryDistributions() {
        return StationaryDistribution.getAllConditionalStationaryDistributions(this.probs[this.order], this.powers[1]);
    }

    @Override
    public void setStatisticForHyperparameters(int[] length, double[] weight) throws Exception {
        if (weight.length != length.length) {
            throw new IllegalArgumentException("The length of both arrays (length, weight) have to be identical.");
        }
        Arrays.fill(this.sumOfHyperParams, 0.0);
        for (int i = 0; i < length.length; ++i) {
            if (weight[i] < 0.0 || length[i] < 0) {
                throw new IllegalArgumentException("check length and weight for entry " + i);
            }
            int l = 0;
            while (l < length[i] && l < this.order) {
                int n = l++;
                this.sumOfHyperParams[n] = this.sumOfHyperParams[n] + weight[i];
            }
            if (this.order >= length[i]) continue;
            int n = this.order;
            this.sumOfHyperParams[n] = this.sumOfHyperParams[n] + (double)(length[i] - this.order) * weight[i];
        }
        this.computeConstantsOfLogPrior();
    }

    public Sample emit(int numberOfSequences, int ... seqLength) throws Exception {
        Random r = new Random();
        Sequence[] seqs = new Sequence[numberOfSequences];
        int l = seqLength[0];
        for (int i = 0; i < numberOfSequences; ++i) {
            if (seqLength.length > 1) {
                l = seqLength[i];
            }
            byte[] bytes = new byte[l];
            int parent = 0;
            int o = 0;
            for (int j = 0; j < l; ++j) {
                int a;
                double p = r.nextDouble();
                for (a = 0; a < this.powers[1] && this.probs[o][parent + a] < p; p -= this.probs[o][parent + a], a = (int)((byte)(a + 1))) {
                }
                bytes[j] = a;
                parent += a;
                parent *= this.powers[1];
                parent %= this.powers[this.order + 1];
                if (o >= this.order) continue;
                ++o;
            }
            seqs[i] = new ByteSequence(this.alphabets, bytes);
        }
        return new Sample("generated from " + this.getInstanceName(), seqs);
    }

    @Override
    public void initializeUniformly(boolean freeParams) {
        double p = 1.0 / (double)this.powers[1];
        for (int o = 0; o <= this.order; ++o) {
            Arrays.fill(this.logNorm[o], 0.0);
            Arrays.fill(this.params[o], 0.0);
            Arrays.fill(this.probs[o], p);
        }
        this.setFreeParams(freeParams);
    }
}

