/*
 * Decompiled with CFR 0.152.
 */
package de.jstacs.sequenceScores.statisticalModels.differentiable.mixture;

import de.jstacs.data.sequences.Sequence;
import de.jstacs.io.NonParsableException;
import de.jstacs.sequenceScores.statisticalModels.differentiable.VariableLengthDiffSM;
import de.jstacs.sequenceScores.statisticalModels.differentiable.mixture.MixtureDiffSM;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.IntList;
import de.jstacs.utils.Normalisation;

/**
 * A {@link MixtureDiffSM} whose components are all {@link VariableLengthDiffSM}s and
 * which therefore implements {@link VariableLengthDiffSM} itself.
 *
 * <p>Every method combines the per-component results with the hidden (mixing)
 * potentials held by the superclass ({@code logHiddenPotential} / {@code hiddenPotential}).
 * The working buffers {@code componentScore}, {@code iList} and {@code dList} are
 * inherited fields that are overwritten on each scoring call.
 */
public class VariableLengthMixtureDiffSM extends MixtureDiffSM implements VariableLengthDiffSM {

    /**
     * Creates a mixture from the given variable-length components by delegating to the
     * {@link MixtureDiffSM} constructor.
     *
     * @param starts    forwarded unchanged to the superclass constructor
     * @param plugIn    forwarded unchanged to the superclass constructor
     * @param component the component models of the mixture
     * @throws CloneNotSupportedException if thrown by the superclass constructor
     */
    public VariableLengthMixtureDiffSM(int starts, boolean plugIn, VariableLengthDiffSM ... component) throws CloneNotSupportedException {
        super(starts, plugIn, component);
    }

    /**
     * Restores an instance from its XML representation by delegating to the
     * {@link MixtureDiffSM} XML constructor.
     *
     * @param xml the XML representation
     * @throws NonParsableException if thrown by the superclass constructor
     */
    public VariableLengthMixtureDiffSM(StringBuffer xml) throws NonParsableException {
        super(xml);
    }

    /**
     * Returns the log-sum over all components of
     * {@code logHiddenPotential[c] + component[c].getLogScoreFor(seq, start, end)}.
     *
     * @param seq   the sequence to score
     * @param start start position passed through to every component
     * @param end   end position passed through to every component
     * @return the log score of the mixture for the given region
     */
    @Override
    public double getLogScoreFor(Sequence seq, int start, int end) {
        for (int c = 0; c < this.function.length; c++) {
            this.componentScore[c] = this.logHiddenPotential[c] + ((VariableLengthDiffSM)this.function[c]).getLogScoreFor(seq, start, end);
        }
        return Normalisation.getLogSum(this.componentScore);
    }

    /**
     * Computes the log score of the mixture and appends the partial derivations with
     * respect to all parameters to {@code indices} / {@code partialDer}.
     *
     * <p>Component parameters are reported first (offset by {@code paramRef[c]} and
     * weighted by the normalised component score), followed by one entry per hidden
     * parameter.
     *
     * @param seq        the sequence to score
     * @param start      start position passed through to every component
     * @param end        end position passed through to every component
     * @param indices    receives the global parameter indices
     * @param partialDer receives the matching partial derivations
     * @return the value returned by {@code Normalisation.logSumNormalisation} for the
     *         weighted component scores
     */
    @Override
    public double getLogScoreAndPartialDerivation(Sequence seq, int start, int end, IntList indices, DoubleList partialDer) {
        // The last paramRef interval is the block of hidden parameters; its width is their count.
        final int lastRef = this.paramRef.length - 1;
        final int hiddenCount = this.paramRef[lastRef] - this.paramRef[lastRef - 1];

        // Collect each component's weighted log score and its local gradient.
        for (int c = 0; c < this.function.length; c++) {
            this.iList[c].clear();
            this.dList[c].clear();
            this.componentScore[c] = this.logHiddenPotential[c] + ((VariableLengthDiffSM)this.function[c]).getLogScoreAndPartialDerivation(seq, start, end, this.iList[c], this.dList[c]);
        }

        // Source and destination are the same array, so componentScore is overwritten in place
        // (presumably with the normalised component weights -- see Normalisation).
        double logScore = Normalisation.logSumNormalisation(this.componentScore, 0, this.function.length, this.componentScore, 0);

        // Translate the component-local parameter indices to global ones and scale the derivations.
        for (int c = 0; c < this.function.length; c++) {
            for (int p = 0; p < this.iList[c].length(); p++) {
                indices.add(this.paramRef[c] + this.iList[c].get(p));
                partialDer.add(this.componentScore[c] * this.dList[c].get(p));
            }
        }

        // Hidden parameters start at paramRef[number of components].
        final int hiddenRef = this.function.length;
        for (int h = 0; h < hiddenCount; h++) {
            indices.add(this.paramRef[hiddenRef] + h);
            partialDer.add(this.componentScore[h] - (this.isNormalized() ? this.hiddenPotential[h] : 0.0));
        }
        return logScore;
    }

    /**
     * Returns the log-sum over all components of
     * {@code logHiddenPotential[c] + component[c].getLogNormalizationConstant(length)}.
     *
     * @param length the sequence length passed through to every component
     * @return the log normalization constant of the mixture for that length
     */
    @Override
    public double getLogNormalizationConstant(int length) {
        double logSum = Double.NEGATIVE_INFINITY;
        for (int c = 0; c < this.logHiddenPotential.length; c++) {
            logSum = Normalisation.getLogSum(logSum, this.logHiddenPotential[c] + ((VariableLengthDiffSM)this.function[c]).getLogNormalizationConstant(length));
        }
        return logSum;
    }

    /**
     * Returns the log of the partial normalization constant for one parameter.
     *
     * <p>If {@link #isNormalized()} is {@code true} the result is always
     * {@link Double#NEGATIVE_INFINITY}. Otherwise the parameter is located via
     * {@code getIndices(parameterIndex)}; as used here, {@code ind[0]} equal to the
     * number of components denotes a hidden parameter with {@code ind[1]} as its
     * component, and any other {@code ind[0]} denotes a component whose local
     * parameter index is {@code ind[1]}.
     *
     * @param parameterIndex the global parameter index
     * @param length         the sequence length passed through to the component
     * @return the log partial normalization constant
     * @throws Exception if thrown by {@code getIndices} or by the component
     */
    @Override
    public double getLogPartialNormalizationConstant(int parameterIndex, int length) throws Exception {
        if (this.isNormalized()) {
            return Double.NEGATIVE_INFINITY;
        }
        final int[] ind = this.getIndices(parameterIndex);
        if (ind[0] != this.function.length) {
            // Parameter belongs to component ind[0]; ind[1] is its index within that component.
            return this.logHiddenPotential[ind[0]] + ((VariableLengthDiffSM)this.function[ind[0]]).getLogPartialNormalizationConstant(ind[1], length);
        } else {
            // Hidden parameter of component ind[1].
            return this.logHiddenPotential[ind[1]] + ((VariableLengthDiffSM)this.function[ind[1]]).getLogNormalizationConstant(length);
        }
    }

    /**
     * Distributes the given statistic over the components.
     *
     * <p>The hyperparameters of the hidden parameters are sum-normalised, and each
     * component receives {@code weight} scaled by its normalised share.
     *
     * @param length the lengths, passed unchanged to every component
     * @param weight the weights, scaled per component before being passed on
     * @throws Exception if thrown by a component
     */
    @Override
    public void setStatisticForHyperparameters(int[] length, double[] weight) throws Exception {
        double[] share = new double[this.getNumberOfComponents()];
        for (int c = 0; c < share.length; c++) {
            share[c] = this.getHyperparameterForHiddenParameter(c);
        }
        Normalisation.sumNormalisation(share);

        // One scratch array is reused (and overwritten) for every component.
        double[] scaled = new double[weight.length];
        for (int c = 0; c < this.function.length; c++) {
            for (int w = 0; w < scaled.length; w++) {
                scaled[w] = weight[w] * share[c];
            }
            ((VariableLengthDiffSM)this.function[c]).setStatisticForHyperparameters(length, scaled);
        }
    }
}

